From bafc367c9b6dcd770c6c60c7d992466eb697ee7f Mon Sep 17 00:00:00 2001 From: xirothedev Date: Sun, 15 Mar 2026 16:16:02 +0700 Subject: [PATCH 1/7] feat(proxy): persist response chains for previous_response_id Add durable response snapshots and replay previous_response_id across HTTP and websocket responses. Prefer the original serving account when it is still eligible, add migration coverage, and sync the OpenSpec change and main specs. --- app/core/openai/requests.py | 7 - .../20260315_120000_add_response_snapshots.py | 58 ++++ app/db/models.py | 15 + app/dependencies.py | 2 + app/modules/proxy/load_balancer.py | 14 + app/modules/proxy/repo_bundle.py | 2 + app/modules/proxy/request_policy.py | 7 + .../proxy/response_snapshots_repository.py | 88 +++++ app/modules/proxy/service.py | 301 +++++++++++++++++- .../design.md | 24 ++ .../proposal.md | 30 ++ .../specs/database-migrations/spec.md | 9 + .../specs/responses-api-compat/spec.md | 24 ++ .../tasks.md | 24 ++ .../specs/responses-api-compat/context.md | 5 +- openspec/specs/responses-api-compat/spec.md | 20 +- .../test_load_balancer_integration.py | 114 +++++++ tests/integration/test_migrations.py | 25 ++ .../test_openai_compat_features.py | 59 +++- .../test_proxy_websocket_responses.py | 166 ++++++++++ tests/unit/test_openai_requests.py | 15 +- tests/unit/test_proxy_utils.py | 104 +++++- uv.lock | 2 +- 23 files changed, 1073 insertions(+), 42 deletions(-) create mode 100644 app/db/alembic/versions/20260315_120000_add_response_snapshots.py create mode 100644 app/modules/proxy/response_snapshots_repository.py create mode 100644 openspec/changes/support-previous-response-id-persistence/design.md create mode 100644 openspec/changes/support-previous-response-id-persistence/proposal.md create mode 100644 openspec/changes/support-previous-response-id-persistence/specs/database-migrations/spec.md create mode 100644 openspec/changes/support-previous-response-id-persistence/specs/responses-api-compat/spec.md create mode 100644 openspec/changes/support-previous-response-id-persistence/tasks.md diff --git a/app/core/openai/requests.py b/app/core/openai/requests.py index 3c179be5..85ab82c5 100644 --- a/app/core/openai/requests.py +++ b/app/core/openai/requests.py @@ -366,13 +366,6 @@ def _ensure_store_false(cls, value: bool | None) -> bool: raise ValueError("store must be false") return False if value is None else value - @field_validator("previous_response_id") - @classmethod - def _reject_previous_response_id(cls, value: str | None) -> str | None: - if value is None: - return value - raise ValueError("previous_response_id is not supported") - @field_validator("tools") @classmethod def _validate_tools(cls, value: list[JsonValue]) -> list[JsonValue]: diff --git a/app/db/alembic/versions/20260315_120000_add_response_snapshots.py b/app/db/alembic/versions/20260315_120000_add_response_snapshots.py new file mode 100644 index 00000000..21c88514 --- /dev/null +++ b/app/db/alembic/versions/20260315_120000_add_response_snapshots.py @@ -0,0 +1,58 @@ +"""add durable response snapshots + +Revision ID: 20260315_120000_add_response_snapshots +Revises: 20260312_120000_add_dashboard_upstream_stream_transport +Create Date: 2026-03-15 12:00:00.000000 +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.engine import Connection + +# revision identifiers, used by Alembic. +revision = "20260315_120000_add_response_snapshots" +down_revision = "20260312_120000_add_dashboard_upstream_stream_transport" +branch_labels = None +depends_on = None + + +def _table_exists(connection: Connection, table_name: str) -> bool: + inspector = sa.inspect(connection) + return inspector.has_table(table_name) + + +def _indexes(connection: Connection, table_name: str) -> set[str]: + inspector = sa.inspect(connection) + if not inspector.has_table(table_name): + return set() + return {str(index["name"]) for index in inspector.get_indexes(table_name) if index.get("name") is not None} + + +def upgrade() -> None: + bind = op.get_bind() + if not _table_exists(bind, "response_snapshots"): + op.create_table( + "response_snapshots", + sa.Column("response_id", sa.String(), nullable=False), + sa.Column("parent_response_id", sa.String(), nullable=True), + sa.Column("account_id", sa.String(), nullable=True), + sa.Column("model", sa.String(), nullable=False), + sa.Column("input_items_json", sa.Text(), nullable=False), + sa.Column("response_json", sa.Text(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()), + sa.PrimaryKeyConstraint("response_id"), + ) + existing_indexes = _indexes(bind, "response_snapshots") + if "idx_response_snapshots_parent_created_at" not in existing_indexes: + op.create_index( + "idx_response_snapshots_parent_created_at", + "response_snapshots", + ["parent_response_id", "created_at"], + unique=False, + ) + + +def downgrade() -> None: + return diff --git a/app/db/models.py b/app/db/models.py index 2d434165..73d6ee42 100644 --- a/app/db/models.py +++ b/app/db/models.py @@ -129,6 +129,21 @@ class RequestLog(Base): error_message: Mapped[str | None] = mapped_column(Text, nullable=True) +class ResponseSnapshot(Base): + __tablename__ = "response_snapshots" + __table_args__ = ( + Index("idx_response_snapshots_parent_created_at", "parent_response_id", "created_at"), + ) + + response_id: Mapped[str] = mapped_column(String, primary_key=True) + parent_response_id: Mapped[str | None] = mapped_column(String, nullable=True) + account_id: Mapped[str | None] = mapped_column(String, nullable=True) + model: Mapped[str] = mapped_column(String, nullable=False) + input_items_json: Mapped[str] = mapped_column(Text, nullable=False) + response_json: Mapped[str] = mapped_column(Text, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), nullable=False) + + class StickySession(Base): __tablename__ = "sticky_sessions" diff --git a/app/dependencies.py b/app/dependencies.py index c25ef49c..8da59b45 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -20,6 +20,7 @@ from app.modules.firewall.service import FirewallService from app.modules.oauth.service import OauthService from app.modules.proxy.repo_bundle import ProxyRepositories +from app.modules.proxy.response_snapshots_repository import ResponseSnapshotsRepository from app.modules.proxy.service import ProxyService from app.modules.proxy.sticky_repository import StickySessionsRepository from app.modules.request_logs.repository import RequestLogsRepository @@ -153,6 +154,7 @@ async def _proxy_repo_context() -> AsyncIterator[ProxyRepositories]: sticky_sessions=StickySessionsRepository(session), api_keys=ApiKeysRepository(session), additional_usage=AdditionalUsageRepository(session), + response_snapshots=ResponseSnapshotsRepository(session), ) diff --git a/app/modules/proxy/load_balancer.py b/app/modules/proxy/load_balancer.py index 32a4c569..8169f377 100644 --- a/app/modules/proxy/load_balancer.py +++ b/app/modules/proxy/load_balancer.py @@ -82,6 +82,7 @@ async def select_account( model: str | None = None, additional_limit_name: str | None = None, exclude_account_ids: Collection[str] | None = None, + preferred_account_id: str | None = None, ) -> AccountSelection: selection_inputs = await self._load_selection_inputs( model=model, @@ -125,6 +126,7 @@ async def select_account( sticky_max_age_seconds=sticky_max_age_seconds, prefer_earlier_reset_accounts=prefer_earlier_reset_accounts, routing_strategy=routing_strategy, + preferred_account_id=preferred_account_id, sticky_repo=repos.sticky_sessions, ) if result.account is not None: @@ -338,8 +340,20 @@ async def _select_with_stickiness( sticky_max_age_seconds: int | None, prefer_earlier_reset_accounts: bool, routing_strategy: RoutingStrategy, + preferred_account_id: str | None, sticky_repo: StickySessionsRepository | None, ) -> SelectionResult: + if preferred_account_id: + preferred_state = next((state for state in states if state.account_id == preferred_account_id), None) + if preferred_state is not None: + preferred_result = select_account( + [preferred_state], + prefer_earlier_reset=prefer_earlier_reset_accounts, + routing_strategy=routing_strategy, + allow_backoff_fallback=False, + ) + if preferred_result.account is not None: + return preferred_result if not sticky_key or not sticky_repo: return select_account( states, diff --git a/app/modules/proxy/repo_bundle.py b/app/modules/proxy/repo_bundle.py index afa6508f..4d102429 100644 --- a/app/modules/proxy/repo_bundle.py +++ b/app/modules/proxy/repo_bundle.py @@ -6,6 +6,7 @@ from app.modules.accounts.repository import AccountsRepository from app.modules.api_keys.repository import ApiKeysRepository +from app.modules.proxy.response_snapshots_repository import ResponseSnapshotsRepository from app.modules.proxy.sticky_repository import StickySessionsRepository from app.modules.request_logs.repository import RequestLogsRepository from app.modules.usage.repository import AdditionalUsageRepository, UsageRepository @@ -19,6 +20,7 @@ class ProxyRepositories: sticky_sessions: StickySessionsRepository api_keys: ApiKeysRepository additional_usage: AdditionalUsageRepository + response_snapshots: ResponseSnapshotsRepository | None = None ProxyRepoFactory = Callable[[], AsyncContextManager[ProxyRepositories]] diff --git a/app/modules/proxy/request_policy.py b/app/modules/proxy/request_policy.py index 52378bdd..399c1421 100644 --- a/app/modules/proxy/request_policy.py +++ b/app/modules/proxy/request_policy.py @@ -78,6 +78,13 @@ def openai_invalid_payload_error(param: str | None = None) -> OpenAIErrorEnvelop return error +def openai_invalid_request_error(message: str, *, param: str | None = None) -> OpenAIErrorEnvelope: + error = openai_error("invalid_request_error", message, error_type="invalid_request_error") + if param: + error["error"]["param"] = param + return error + + def normalize_responses_request_payload( payload: dict[str, JsonValue], *, diff --git a/app/modules/proxy/response_snapshots_repository.py b/app/modules/proxy/response_snapshots_repository.py new file mode 100644 index 00000000..9665cc5b --- /dev/null +++ b/app/modules/proxy/response_snapshots_repository.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import json + +from sqlalchemy import select +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.dialects.sqlite import insert as sqlite_insert +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql import Insert, func + +from app.core.types import JsonValue +from app.db.models import ResponseSnapshot + + +class ResponseSnapshotsRepository: + def __init__(self, session: AsyncSession) -> None: + self._session = session + + async def get(self, response_id: str) -> ResponseSnapshot | None: + if not response_id: + return None + result = await self._session.execute( + select(ResponseSnapshot).where(ResponseSnapshot.response_id == response_id) + ) + return result.scalar_one_or_none() + + async def upsert( + self, + *, + response_id: str, + parent_response_id: str | None, + account_id: str | None, + model: str, + input_items: list[JsonValue], + response_payload: dict[str, JsonValue], + ) -> ResponseSnapshot: + statement = self._build_upsert_statement( + response_id=response_id, + parent_response_id=parent_response_id, + account_id=account_id, + model=model, + input_items_json=json.dumps(input_items, ensure_ascii=False, separators=(",", ":")), + response_json=json.dumps(response_payload, ensure_ascii=False, separators=(",", ":")), + ) + await self._session.execute(statement) + await self._session.commit() + snapshot = await self.get(response_id) + if snapshot is None: + raise RuntimeError(f"ResponseSnapshot upsert failed for response_id={response_id!r}") + await self._session.refresh(snapshot) + return snapshot + + def _build_upsert_statement( + self, + *, + response_id: str, + parent_response_id: str | None, + account_id: str | None, + model: str, + input_items_json: str, + response_json: str, + ) -> Insert: + dialect = self._session.get_bind().dialect.name + if dialect == "postgresql": + insert_fn = pg_insert + elif dialect == "sqlite": + insert_fn = sqlite_insert + else: + raise RuntimeError(f"ResponseSnapshot upsert unsupported for dialect={dialect!r}") + statement = insert_fn(ResponseSnapshot).values( + response_id=response_id, + parent_response_id=parent_response_id, + account_id=account_id, + model=model, + input_items_json=input_items_json, + response_json=response_json, + ) + return statement.on_conflict_do_update( + index_elements=[ResponseSnapshot.response_id], + set_={ + "parent_response_id": parent_response_id, + "account_id": account_id, + "model": model, + "input_items_json": input_items_json, + "response_json": response_json, + "created_at": func.now(), + }, + ) diff --git a/app/modules/proxy/service.py b/app/modules/proxy/service.py index 1fad9f92..c851d422 100644 --- a/app/modules/proxy/service.py +++ b/app/modules/proxy/service.py @@ -7,7 +7,7 @@ import time from collections import deque from collections.abc import Sequence -from dataclasses import dataclass +from dataclasses import dataclass, field from hashlib import sha256 from typing import AsyncIterator, Mapping, NoReturn from uuid import uuid4 @@ -51,8 +51,8 @@ from app.core.exceptions import AppError, ProxyAuthError, ProxyRateLimitError from app.core.openai.exceptions import ClientPayloadError from app.core.openai.models import CompactResponsePayload, OpenAIEvent, OpenAIResponsePayload -from app.core.openai.parsing import parse_sse_event -from app.core.openai.requests import ResponsesCompactRequest, ResponsesRequest +from app.core.openai.parsing import parse_response_payload, parse_sse_event +from app.core.openai.requests import ResponsesCompactRequest, ResponsesRequest, _sanitize_input_items from app.core.types import JsonValue from app.core.usage.types import UsageWindowRow from app.core.utils.request_id import ensure_request_id, get_request_id @@ -86,6 +86,7 @@ from app.modules.proxy.repo_bundle import ProxyRepoFactory, ProxyRepositories from app.modules.proxy.request_policy import ( apply_api_key_enforcement, + openai_invalid_request_error, normalize_responses_request_payload, openai_invalid_payload_error, openai_validation_error, @@ -127,6 +128,14 @@ class _AffinityPolicy: max_age_seconds: int | None = None +@dataclass(frozen=True, slots=True) +class _ResolvedResponsesRequest: + payload: ResponsesRequest + current_input_items: list[JsonValue] + parent_response_id: str | None = None + preferred_account_id: str | None = None + + def _resolve_upstream_stream_transport(upstream_stream_transport: str) -> str | None: if upstream_stream_transport == "default": return None @@ -622,6 +631,14 @@ async def proxy_responses_websocket( _serialize_websocket_error_event(_app_error_to_websocket_event(exc)) ) continue + except ProxyResponseError as exc: + async with client_send_lock: + await websocket.send_text( + _serialize_websocket_error_event( + _wrapped_websocket_error_event(exc.status_code, exc.payload) + ) + ) + continue except ClientPayloadError as exc: async with client_send_lock: await websocket.send_text( @@ -705,19 +722,25 @@ async def proxy_responses_websocket( ) ) continue + connect_parameters = inspect.signature(self._connect_proxy_websocket).parameters + connect_kwargs = { + "sticky_key": request_affinity.key, + "sticky_kind": request_affinity.kind, + "reallocate_sticky": request_affinity.reallocate_sticky, + "sticky_max_age_seconds": request_affinity.max_age_seconds, + "prefer_earlier_reset": prefer_earlier_reset, + "routing_strategy": routing_strategy, + "model": request_state.model, + "request_state": request_state, + "api_key": api_key, + "client_send_lock": client_send_lock, + "websocket": websocket, + } + if "preferred_account_id" in connect_parameters: + connect_kwargs["preferred_account_id"] = request_state.preferred_account_id account, upstream = await self._connect_proxy_websocket( filtered_headers, - sticky_key=request_affinity.key, - sticky_kind=request_affinity.kind, - reallocate_sticky=request_affinity.reallocate_sticky, - sticky_max_age_seconds=request_affinity.max_age_seconds, - prefer_earlier_reset=prefer_earlier_reset, - routing_strategy=routing_strategy, - model=request_state.model, - request_state=request_state, - api_key=api_key, - client_send_lock=client_send_lock, - websocket=websocket, + **connect_kwargs, ) if upstream is None or account is None: if request_state_registered: @@ -817,6 +840,9 @@ async def _prepare_websocket_response_create_request( apply_api_key_enforcement(responses_payload, refreshed_api_key) validate_model_access(refreshed_api_key, responses_payload.model) + resolved_request = await self._resolve_previous_response_request(responses_payload) + responses_payload = resolved_request.payload + upstream_payload = dict(responses_payload.to_payload()) forwarded_service_tier = _normalize_service_tier_value(upstream_payload.get("service_tier")) reservation = await self._reserve_websocket_api_key_usage( @@ -834,6 +860,9 @@ async def _prepare_websocket_response_create_request( reasoning_effort=responses_payload.reasoning.effort if responses_payload.reasoning else None, api_key_reservation=reservation, started_at=time.monotonic(), + parent_response_id=resolved_request.parent_response_id, + preferred_account_id=resolved_request.preferred_account_id, + current_input_items=resolved_request.current_input_items, awaiting_response_created=True, ), affinity_policy=_sticky_key_for_responses_request( @@ -855,6 +884,7 @@ async def _connect_proxy_websocket( prefer_earlier_reset: bool, routing_strategy: RoutingStrategy, model: str | None, + preferred_account_id: str | None, request_state: _WebSocketRequestState, api_key: ApiKeyData | None, client_send_lock: anyio.Lock, @@ -875,6 +905,7 @@ async def _connect_proxy_websocket( prefer_earlier_reset_accounts=prefer_earlier_reset, routing_strategy=routing_strategy, model=model, + preferred_account_id=preferred_account_id, ) except ProxyResponseError as exc: if _is_proxy_budget_exhausted_error(exc): @@ -1279,6 +1310,7 @@ async def _process_upstream_websocket_text( actual_service_tier = _service_tier_from_event_payload(payload) if actual_service_tier is not None: request_state.service_tier = actual_service_tier + _collect_output_item_event(payload, request_state.output_items) if ( event_type in {"response.completed", "response.failed", "response.incomplete", "error"} and pending_requests @@ -1442,6 +1474,16 @@ async def _finalize_websocket_request_state( elif settlement.record_success: await self._load_balancer.record_success(account) + if event_type == "response.completed": + await self._persist_response_snapshot( + response_id=response_id, + parent_response_id=request_state.parent_response_id, + account_id=account_id_value, + model=request_state.model or "", + input_items=request_state.current_input_items, + response_payload=_terminal_response_payload(payload, request_state.output_items), + ) + latency_ms = int((time.monotonic() - request_state.started_at) * 1000) cached_input_tokens = usage.input_tokens_details.cached_tokens if usage and usage.input_tokens_details else None reasoning_tokens = ( @@ -1817,6 +1859,8 @@ async def _stream_with_retry( base_settings = get_settings() settings = await get_settings_cache().get() deadline = start + base_settings.proxy_request_budget_seconds + resolved_request = await self._resolve_previous_response_request(payload) + payload = resolved_request.payload prefer_earlier_reset = settings.prefer_earlier_reset_accounts upstream_stream_transport = _resolve_upstream_stream_transport(settings.upstream_stream_transport) affinity = _sticky_key_for_responses_request( @@ -1867,6 +1911,7 @@ async def _stream_with_retry( prefer_earlier_reset_accounts=prefer_earlier_reset, routing_strategy=routing_strategy, model=payload.model, + preferred_account_id=resolved_request.preferred_account_id, ) except ProxyResponseError as exc: error = _parse_openai_error(exc.payload) @@ -2011,6 +2056,8 @@ async def _stream_with_retry( attempt < max_attempts - 1, api_key=api_key, settlement=settlement, + parent_response_id=resolved_request.parent_response_id, + snapshot_input_items=resolved_request.current_input_items, suppress_text_done_events=suppress_text_done_events, upstream_stream_transport=upstream_stream_transport, request_transport=request_transport, @@ -2137,6 +2184,8 @@ async def _stream_with_retry( False, api_key=api_key, settlement=settlement, + parent_response_id=resolved_request.parent_response_id, + snapshot_input_items=resolved_request.current_input_items, suppress_text_done_events=suppress_text_done_events, upstream_stream_transport=upstream_stream_transport, request_transport=request_transport, @@ -2248,6 +2297,8 @@ async def _stream_once( *, api_key: ApiKeyData | None, settlement: _StreamSettlement, + parent_response_id: str | None, + snapshot_input_items: list[JsonValue], suppress_text_done_events: bool, upstream_stream_transport: str | None, request_transport: str, @@ -2266,6 +2317,9 @@ async def _stream_once( error_message = None usage = None saw_text_delta = False + output_items: dict[int, dict[str, JsonValue]] = {} + completed_response_payload: dict[str, JsonValue] | None = None + response_id: str | None = None try: if upstream_stream_transport is not None: @@ -2293,6 +2347,8 @@ async def _stream_once( first_payload = parse_sse_data_json(first) event = parse_sse_event(first) event_type = _event_type_from_payload(event, first_payload) + _collect_output_item_event(first_payload, output_items) + response_id = _websocket_response_id(event, first_payload) or response_id event_service_tier = _service_tier_from_event_payload(first_payload) if event_service_tier is not None: actual_service_tier = event_service_tier @@ -2330,6 +2386,7 @@ async def _stream_once( if event and event.type in ("response.completed", "response.incomplete"): usage = event.response.usage if event.response else None + completed_response_payload = _terminal_response_payload(first_payload, output_items) if event.type == "response.incomplete": status = "error" @@ -2349,6 +2406,8 @@ async def _stream_once( event_payload = parse_sse_data_json(line) event = parse_sse_event(line) event_type = _event_type_from_payload(event, event_payload) + _collect_output_item_event(event_payload, output_items) + response_id = _websocket_response_id(event, event_payload) or response_id event_service_tier = _service_tier_from_event_payload(event_payload) if event_service_tier is not None: actual_service_tier = event_service_tier @@ -2380,6 +2439,7 @@ async def _stream_once( settlement.account_health_error = _should_penalize_stream_error(error_code) if event_type in ("response.completed", "response.incomplete"): usage = event.response.usage if event.response else None + completed_response_payload = _terminal_response_payload(event_payload, output_items) if event_type == "response.incomplete": status = "error" yield line @@ -2433,6 +2493,15 @@ async def _stream_once( requested_service_tier=requested_service_tier, actual_service_tier=actual_service_tier, ) + if status == "success": + await self._persist_response_snapshot( + response_id=response_id, + parent_response_id=parent_response_id, + account_id=account_id_value, + model=model, + input_items=snapshot_input_items, + response_payload=completed_response_payload, + ) async def _write_request_log( self, @@ -2712,6 +2781,7 @@ async def _select_account_with_budget( routing_strategy: RoutingStrategy = "usage_weighted", model: str | None = None, additional_limit_name: str | None = None, + preferred_account_id: str | None = None, ) -> AccountSelection: remaining_budget = _remaining_budget_seconds(deadline) if remaining_budget <= 0: @@ -2730,11 +2800,111 @@ async def _select_account_with_budget( routing_strategy=routing_strategy, model=model, additional_limit_name=additional_limit_name, + preferred_account_id=preferred_account_id, ) except TimeoutError: logger.warning("%s account selection exceeded request budget request_id=%s", kind.title(), request_id) _raise_proxy_budget_exhausted() + async def _resolve_previous_response_request(self, payload: ResponsesRequest) -> _ResolvedResponsesRequest: + current_input_items = _clone_json_list(payload.input) + previous_response_id = payload.previous_response_id + if not previous_response_id: + return _ResolvedResponsesRequest(payload=payload, current_input_items=current_input_items) + + replay_items, preferred_account_id = await self._resolve_previous_response_chain(previous_response_id) + resolved_payload = payload.model_copy( + deep=True, + update={ + "input": [*replay_items, *current_input_items], + "previous_response_id": None, + }, + ) + return _ResolvedResponsesRequest( + payload=resolved_payload, + current_input_items=current_input_items, + parent_response_id=previous_response_id, + preferred_account_id=preferred_account_id, + ) + + async def _resolve_previous_response_chain(self, response_id: str) -> tuple[list[JsonValue], str | None]: + if not response_id: + _raise_unknown_previous_response_id() + + async with self._repo_factory() as repos: + if repos.response_snapshots is None: + _raise_unknown_previous_response_id() + chain: list[dict[str, str | None]] = [] + preferred_account_id: str | None = None + current_id = response_id + seen_response_ids: set[str] = set() + + while current_id: + if current_id in seen_response_ids: + _raise_unknown_previous_response_id() + seen_response_ids.add(current_id) + snapshot = await repos.response_snapshots.get(current_id) + if snapshot is None: + _raise_unknown_previous_response_id() + chain.append( + { + "response_id": snapshot.response_id, + "parent_response_id": snapshot.parent_response_id, + "account_id": snapshot.account_id, + "input_items_json": snapshot.input_items_json, + "response_json": snapshot.response_json, + } + ) + if preferred_account_id is None and snapshot.account_id: + preferred_account_id = snapshot.account_id + current_id = snapshot.parent_response_id or "" + + replay_items: list[JsonValue] = [] + for snapshot in reversed(chain): + snapshot_input = _decode_snapshot_json_list(snapshot["input_items_json"] or "[]") + snapshot_response = _decode_snapshot_json_mapping(snapshot["response_json"] or "{}") + replay_items.extend(snapshot_input) + replay_items.extend(_replayable_response_output_items(snapshot_response)) + return replay_items, preferred_account_id + + async def _persist_response_snapshot( + self, + *, + response_id: str | None, + parent_response_id: str | None, + account_id: str | None, + model: str, + input_items: list[JsonValue], + response_payload: dict[str, JsonValue] | None, + ) -> None: + if not response_id or response_payload is None: + return + response_payload = dict(response_payload) + response_payload.setdefault("id", response_id) + if parse_response_payload(response_payload) is None: + return + + with anyio.CancelScope(shield=True): + try: + async with self._repo_factory() as repos: + if repos.response_snapshots is None: + return + await repos.response_snapshots.upsert( + response_id=response_id, + parent_response_id=parent_response_id, + account_id=account_id, + model=model, + input_items=input_items, + response_payload=response_payload, + ) + except Exception: + logger.warning( + "Failed to persist response snapshot response_id=%s parent_response_id=%s", + response_id, + parent_response_id, + exc_info=True, + ) + async def _handle_proxy_error(self, account: Account, exc: ProxyResponseError) -> None: error = _parse_openai_error(exc.payload) code = _normalize_error_code( @@ -2828,6 +2998,10 @@ class _WebSocketRequestState: api_key_reservation: ApiKeyUsageReservationData | None started_at: float response_id: str | None = None + parent_response_id: str | None = None + preferred_account_id: str | None = None + current_input_items: list[JsonValue] = field(default_factory=list) + output_items: dict[int, dict[str, JsonValue]] = field(default_factory=dict) awaiting_response_created: bool = False @@ -2877,6 +3051,105 @@ def _websocket_response_id(event: OpenAIEvent | None, payload: dict[str, JsonVal return stripped or None +def _collect_output_item_event( + payload: dict[str, JsonValue] | None, + output_items: dict[int, dict[str, JsonValue]], +) -> None: + if not isinstance(payload, dict): + return + event_type = payload.get("type") + if event_type not in ("response.output_item.added", "response.output_item.done"): + return + output_index = payload.get("output_index") + item = payload.get("item") + if not isinstance(output_index, int) or not isinstance(item, dict): + return + output_items[output_index] = dict(item) + + +def _merge_response_output_items( + response: dict[str, JsonValue], + output_items: dict[int, dict[str, JsonValue]], +) -> dict[str, JsonValue]: + merged = dict(response) + existing_output = response.get("output") + if isinstance(existing_output, list) and existing_output: + return merged + if output_items: + merged["output"] = [item for _, item in sorted(output_items.items())] + return merged + + +def _terminal_response_payload( + payload: dict[str, JsonValue] | None, + output_items: dict[int, dict[str, JsonValue]], +) -> dict[str, JsonValue] | None: + if not isinstance(payload, dict): + return None + response = payload.get("response") + if not isinstance(response, dict): + return None + return _merge_response_output_items(response, output_items) + + +def _replayable_response_output_items(response_payload: dict[str, JsonValue]) -> list[JsonValue]: + output_value = response_payload.get("output") + if not isinstance(output_value, list): + _raise_unknown_previous_response_id() + filtered_output = [ + item + for item in output_value + if not (isinstance(item, dict) and item.get("type") == "reasoning") + ] + normalized_output = _sanitize_input_items(filtered_output) + replay_items: list[JsonValue] = [] + for item in normalized_output: + if isinstance(item, dict) and item.get("role") == "assistant": + replay_items.append({"role": "assistant", "content": item.get("content")}) + continue + replay_items.append(item) + return replay_items + + +def _clone_json_list(value: JsonValue) -> list[JsonValue]: + if not isinstance(value, list): + raise TypeError("expected list input items") + return json.loads(json.dumps(value, ensure_ascii=False)) + + +def _decode_snapshot_json_list(value: str) -> list[JsonValue]: + try: + decoded = json.loads(value) + except json.JSONDecodeError as exc: + raise ProxyResponseError( + 400, + openai_invalid_request_error("Unknown previous_response_id", param="previous_response_id"), + ) from exc + if not isinstance(decoded, list): + _raise_unknown_previous_response_id() + return decoded + + +def _decode_snapshot_json_mapping(value: str) -> dict[str, JsonValue]: + try: + decoded = json.loads(value) + except json.JSONDecodeError as exc: + raise ProxyResponseError( + 400, + openai_invalid_request_error("Unknown previous_response_id", param="previous_response_id"), + ) from exc + if not isinstance(decoded, dict): + _raise_unknown_previous_response_id() + return decoded + + +def _raise_unknown_previous_response_id() -> NoReturn: + raise ProxyResponseError( + 400, + openai_invalid_request_error("Unknown previous_response_id", param="previous_response_id"), + ) + + def _find_websocket_request_state_by_response_id( pending_requests: deque[_WebSocketRequestState], response_id: str, diff --git a/openspec/changes/support-previous-response-id-persistence/design.md b/openspec/changes/support-previous-response-id-persistence/design.md new file mode 100644 index 00000000..5eaebfde --- /dev/null +++ b/openspec/changes/support-previous-response-id-persistence/design.md @@ -0,0 +1,24 @@ +## Overview + +This change makes `previous_response_id` a proxy-local feature instead of an upstream passthrough. The proxy will persist the normalized request/response state required to reconstruct a prior conversation turn, then rebuild the upstream `input` array whenever a new request references an earlier response id. + +## Decisions + +### Persist snapshots in a dedicated table instead of reusing request logs + +`request_logs` only stores metrics and cannot reconstruct input/output history. Add a dedicated snapshot table keyed by `response_id` with `parent_response_id`, `account_id`, normalized turn input JSON, and terminal response JSON. Keep `account_id` as plain text instead of a cascading foreign key so deleted accounts do not erase replay state. + +### Replay prior turn input/output, not prior instructions + +OpenAI Responses semantics do not carry forward prior `instructions` when a request uses `previous_response_id`. The resolver will recursively flatten `turn_input + prior_response.output` for each parent snapshot, then append the current request input while leaving the current request's `instructions` untouched. + +### Prefer the previous account without adding a new sticky-session kind + +The stored snapshot already includes the prior `account_id`, so adding a parallel sticky row would duplicate routing state. Extend account selection with an optional preferred account: use it when the account is still eligible for the current request, otherwise log the miss and fall back to the existing selection path. + +### Persist snapshots from shared stream settlement state + +Collected non-stream responses already reconstruct terminal `response.output` from `response.output_item.*` events, but streaming paths do not preserve that state. Move output-item accumulation into a shared helper and carry snapshot metadata through HTTP stream settlement and WebSocket request state so every successful terminal response can persist the same canonical snapshot payload. + +## Verification + diff --git a/openspec/changes/support-previous-response-id-persistence/proposal.md b/openspec/changes/support-previous-response-id-persistence/proposal.md new file mode 100644 index 00000000..27cf6413 --- /dev/null +++ b/openspec/changes/support-previous-response-id-persistence/proposal.md @@ -0,0 +1,30 @@ +## Why + +`codex-lb` currently rejects `previous_response_id` outright even though OpenAI-style Responses clients use it for multi-turn conversation state. That makes `/v1/responses` incompatible with clients that continue a conversation by referencing the last response id instead of resending the full transcript. + +The proxy cannot delegate this feature to the ChatGPT upstream because upstream does not accept `previous_response_id`. To close the compatibility gap, the proxy must persist enough local conversation state to rebuild history after restart and replay it upstream as explicit input items. + +## What Changes + +- persist response-chain snapshots keyed by `response_id` so `previous_response_id` survives restarts +- resolve `previous_response_id` into replayable input history before forwarding upstream +- prefer the account that served the referenced response when it remains eligible, while falling back to normal routing if it does not +- support the same behavior for HTTP streaming, HTTP collected responses, and WebSocket Responses traffic +- add migration and regression coverage for snapshot persistence, chain resolution, and invalid `previous_response_id` failures + +## Capabilities + +### New Capabilities + +- `database-migrations`: durable response-chain snapshots for `previous_response_id` replay + +### Modified Capabilities + +- `responses-api-compat`: `/v1/responses` and WebSocket Responses requests may continue prior proxy-managed conversations via `previous_response_id` +- `responses-api-compat`: `previous_response_id` resolution prefers the prior account but falls back to normal routing when necessary + +## Impact + +- Code: proxy request normalization, stream/websocket response settlement, load balancer account selection, new response snapshot repository/service, Alembic migration +- Tests: Responses compatibility integration coverage, proxy routing/unit coverage, migration coverage +- Specs: `openspec/specs/responses-api-compat/spec.md`, `openspec/specs/database-migrations/spec.md` diff --git a/openspec/changes/support-previous-response-id-persistence/specs/database-migrations/spec.md b/openspec/changes/support-previous-response-id-persistence/specs/database-migrations/spec.md new file mode 100644 index 00000000..5128068d --- /dev/null +++ b/openspec/changes/support-previous-response-id-persistence/specs/database-migrations/spec.md @@ -0,0 +1,9 @@ +## MODIFIED Requirements + +### Requirement: Proxy durable state migrations remain additive and reproducible +Database migrations for proxy-managed durable state MUST create the tables and indexes required by new runtime capabilities without depending on mutable runtime configuration. + +#### Scenario: response snapshot table is created on upgrade +- **WHEN** the application migrates a database that predates durable `previous_response_id` support +- **THEN** the migration creates a response snapshot table keyed by `response_id` +- **AND** the table includes indexed parent linkage required for recursive chain resolution after restart diff --git a/openspec/changes/support-previous-response-id-persistence/specs/responses-api-compat/spec.md b/openspec/changes/support-previous-response-id-persistence/specs/responses-api-compat/spec.md new file mode 100644 index 00000000..07667a6f --- /dev/null +++ b/openspec/changes/support-previous-response-id-persistence/specs/responses-api-compat/spec.md @@ -0,0 +1,24 @@ +## MODIFIED Requirements + +### Requirement: Support Responses input types and conversation constraints +The service MUST accept `input` as either a string or an array of input items. When `input` is a string, the service MUST normalize it into a single user input item with `input_text` content before forwarding upstream. When a client supplies `previous_response_id`, the service MUST resolve that id from proxy-managed durable response snapshots, rebuild the prior conversation input/output history as explicit upstream input items, and continue to reject requests that include both `conversation` and `previous_response_id`. + +#### Scenario: previous response id resolves to replayable history +- **WHEN** the client sends `previous_response_id` that matches a persisted prior response snapshot +- **THEN** the proxy rebuilds the prior chain as upstream `input` items before appending the current request input +- **AND** the current request's `instructions` remain the only top-level instructions forwarded upstream + +#### Scenario: unknown previous response id +- **WHEN** the client sends `previous_response_id` that does not match a persisted prior response snapshot +- **THEN** the service returns a 400 OpenAI-format error envelope with `param=previous_response_id` + +#### Scenario: conversation and previous response id conflict +- **WHEN** the client provides both `conversation` and `previous_response_id` +- **THEN** the service returns a 4xx response with an OpenAI error envelope indicating invalid parameters + +### Requirement: Previous-response replay preserves routing continuity +When a request resolves `previous_response_id`, the service MUST prefer the account that served the referenced response if that account is still eligible for the current request. If the stored account is unavailable, the service MUST fall back to the existing account-selection flow instead of failing solely because the preferred account cannot serve the request. + +#### Scenario: preferred prior account still eligible +- **WHEN** the client sends `previous_response_id` that resolves to a snapshot whose account can still serve the current request +- **THEN** the proxy routes the request to that same account diff --git a/openspec/changes/support-previous-response-id-persistence/tasks.md b/openspec/changes/support-previous-response-id-persistence/tasks.md new file mode 100644 index 00000000..b3936ee5 --- /dev/null +++ b/openspec/changes/support-previous-response-id-persistence/tasks.md @@ -0,0 +1,24 @@ +## 1. Snapshot persistence + +- [x] 1.1 Add a durable response snapshot table and repository for per-response request/response chain state +- [x] 1.2 Wire snapshot repository access into proxy dependencies and service flows +- [x] 1.3 Add migration and regression coverage for snapshot table creation + +## 2. `previous_response_id` resolution + +- [x] 2.1 Stop rejecting `previous_response_id` by default while continuing to reject `conversation` plus `previous_response_id` +- [x] 2.2 Resolve `previous_response_id` into replayable upstream input history without carrying prior instructions +- [x] 2.3 Return explicit OpenAI-format errors when `previous_response_id` cannot be resolved from persisted snapshots + +## 3. Routing continuity + +- [x] 3.1 Prefer the stored prior account for resolved `previous_response_id` requests +- [x] 3.2 Fall back to normal account selection when the preferred account is unavailable or ineligible +- [x] 3.3 Add regression coverage for prefer-with-fallback routing + +## 4. Stream and websocket parity + +- [x] 4.1 Persist snapshots for HTTP streaming and collected `/v1/responses` requests from shared output-item accumulation +- [x] 4.2 Persist snapshots for WebSocket Responses requests and reuse them on follow-up creates +- [x] 4.3 Add integration coverage for HTTP and WebSocket chain continuity across service restart + diff --git a/openspec/specs/responses-api-compat/context.md b/openspec/specs/responses-api-compat/context.md index d94dc50e..5663360a 100644 --- a/openspec/specs/responses-api-compat/context.md +++ b/openspec/specs/responses-api-compat/context.md @@ -2,7 +2,7 @@ ## Purpose and Scope -This capability implements OpenAI-compatible behavior for `POST /v1/responses`, including request validation, streaming events, non-streaming aggregation, and OpenAI-style error envelopes. The scope is limited to what the ChatGPT upstream can provide; unsupported features are explicitly rejected. +This capability implements OpenAI-compatible behavior for `POST /v1/responses`, including request validation, streaming events, non-streaming aggregation, durable `previous_response_id` replay, and OpenAI-style error envelopes. The scope is limited to what the ChatGPT upstream can provide; unsupported features are explicitly rejected unless the proxy can emulate them locally. See `openspec/specs/responses-api-compat/spec.md` for normative requirements. @@ -18,7 +18,7 @@ See `openspec/specs/responses-api-compat/spec.md` for normative requirements. - Upstream limitations determine available modalities, tool output, and overflow handling. - `store=true` is rejected; responses are not persisted. - `include` values must be on the documented allowlist. -- `previous_response_id` and `truncation` are rejected. +- `previous_response_id` is implemented as a proxy-local feature backed by durable response snapshots; `truncation` is still rejected. - `/v1/responses/compact` keeps a final-JSON contract and preserves the raw upstream `/codex/responses/compact` payload shape as the canonical next context window instead of rewriting it through buffered `/codex/responses` streaming. - Compact transport failures fail closed with respect to semantics: no surrogate `/codex/responses` fallback and no local compact-window reconstruction. - Compact transport may use bounded same-contract retries only for safe pre-body transport failures and `401 -> refresh -> retry`. @@ -41,6 +41,7 @@ See `openspec/specs/responses-api-compat/spec.md` for normative requirements. - **Upstream error / no accounts:** Non-streaming responses return an OpenAI error envelope with 5xx status. - **Compact upstream transport/client failure:** Retry only inside `/codex/responses/compact` when the failure is safely retryable; otherwise return an explicit upstream error without surrogate fallback. - **Invalid request payloads:** Return 4xx with `invalid_request_error`. +- **Unknown `previous_response_id`:** Return 4xx with `invalid_request_error` and `param=previous_response_id`. ## Error Envelope Mapping (Reference) diff --git a/openspec/specs/responses-api-compat/spec.md b/openspec/specs/responses-api-compat/spec.md index f57e6197..cc2a7a02 100644 --- a/openspec/specs/responses-api-compat/spec.md +++ b/openspec/specs/responses-api-compat/spec.md @@ -16,7 +16,7 @@ The service MUST accept POST requests to `/v1/responses` with a JSON body and MU - **THEN** the service returns a 4xx response with an OpenAI error envelope describing the invalid parameter ### Requirement: Support Responses input types and conversation constraints -The service MUST accept `input` as either a string or an array of input items. When `input` is a string, the service MUST normalize it into a single user input item with `input_text` content before forwarding upstream. The service MUST reject `previous_response_id` with an OpenAI error envelope because upstream does not support it. The service MUST continue to reject requests that include both `conversation` and `previous_response_id`. +The service MUST accept `input` as either a string or an array of input items. When `input` is a string, the service MUST normalize it into a single user input item with `input_text` content before forwarding upstream. When the client supplies `previous_response_id`, the service MUST resolve that id from proxy-managed durable response snapshots, rebuild the prior conversation input/output history as explicit upstream input items, and continue to reject requests that include both `conversation` and `previous_response_id`. #### Scenario: String input - **WHEN** the client sends `input` as a string @@ -26,13 +26,25 @@ The service MUST accept `input` as either a string or an array of input items. W - **WHEN** the client sends `input` as an array of input items - **THEN** the request is accepted and each item is forwarded in order +#### Scenario: previous_response_id resolved from durable snapshots +- **WHEN** the client provides `previous_response_id` that matches a persisted prior response snapshot +- **THEN** the service forwards the rebuilt prior input/output history before the current request input +- **AND** it does not carry forward prior `instructions` + #### Scenario: conversation and previous_response_id conflict - **WHEN** the client provides both `conversation` and `previous_response_id` - **THEN** the service returns a 4xx response with an OpenAI error envelope indicating invalid parameters -#### Scenario: previous_response_id provided -- **WHEN** the client provides `previous_response_id` -- **THEN** the service returns a 4xx response with an OpenAI error envelope indicating the unsupported parameter +### Requirement: Prefer prior account continuity for resolved previous_response_id +When a request resolves `previous_response_id`, the service MUST prefer the account that served the referenced response if that account is still eligible for the current request. If the stored account is unavailable, the service MUST fall back to the existing account-selection flow instead of failing solely because the preferred account cannot serve the request. + +#### Scenario: Preferred prior account remains eligible +- **WHEN** the client sends `previous_response_id` that resolves to a snapshot whose account can still serve the current request +- **THEN** the service routes the request to that account ahead of normal balancing + +#### Scenario: Preferred prior account unavailable +- **WHEN** the client sends `previous_response_id` that resolves to a snapshot whose account can no longer serve the current request +- **THEN** the service falls back to normal account selection without returning an error solely because the preferred account is unavailable ### Requirement: Reject input_file file_id in Responses The service MUST reject `input_file.file_id` in Responses input items and return a 4xx OpenAI invalid_request_error with message "Invalid request payload". diff --git a/tests/integration/test_load_balancer_integration.py b/tests/integration/test_load_balancer_integration.py index 19bb5c07..d61df488 100644 --- a/tests/integration/test_load_balancer_integration.py +++ b/tests/integration/test_load_balancer_integration.py @@ -427,3 +427,117 @@ async def test_load_balancer_filters_accounts_by_persisted_additional_usage(db_s assert selection.account is not None assert selection.account.id == eligible_account.id + + +@pytest.mark.asyncio +async def test_load_balancer_prefers_requested_account_when_eligible(db_setup): + encryptor = TokenEncryptor() + now = utcnow() + reset_at = int(now.replace(tzinfo=timezone.utc).timestamp()) + 3600 + + preferred = Account( + id="acc_preferred_eligible", + email="preferred-eligible@example.com", + plan_type="plus", + access_token_encrypted=encryptor.encrypt("preferred-access"), + refresh_token_encrypted=encryptor.encrypt("preferred-refresh"), + id_token_encrypted=encryptor.encrypt("preferred-id"), + last_refresh=now, + status=AccountStatus.ACTIVE, + deactivation_reason=None, + ) + other = Account( + id="acc_preferred_other", + email="preferred-other@example.com", + plan_type="plus", + access_token_encrypted=encryptor.encrypt("other-access"), + refresh_token_encrypted=encryptor.encrypt("other-refresh"), + id_token_encrypted=encryptor.encrypt("other-id"), + last_refresh=now, + status=AccountStatus.ACTIVE, + deactivation_reason=None, + ) + + async with SessionLocal() as session: + accounts_repo = AccountsRepository(session) + usage_repo = UsageRepository(session) + await accounts_repo.upsert(preferred) + await accounts_repo.upsert(other) + await usage_repo.add_entry( + account_id=preferred.id, + used_percent=80.0, + window="primary", + reset_at=reset_at, + window_minutes=300, + ) + await usage_repo.add_entry( + account_id=other.id, + used_percent=5.0, + window="primary", + reset_at=reset_at, + window_minutes=300, + ) + + balancer = LoadBalancer(_repo_factory) + selection = await balancer.select_account(preferred_account_id=preferred.id) + + assert selection.account is not None + assert selection.account.id == preferred.id + + +@pytest.mark.asyncio +async def test_load_balancer_falls_back_when_preferred_account_unavailable(db_setup): + encryptor = TokenEncryptor() + now = utcnow() + now_epoch = int(now.replace(tzinfo=timezone.utc).timestamp()) + reset_at = now_epoch + 3600 + + preferred = Account( + id="acc_preferred_blocked", + email="preferred-blocked@example.com", + plan_type="plus", + access_token_encrypted=encryptor.encrypt("preferred-blocked-access"), + refresh_token_encrypted=encryptor.encrypt("preferred-blocked-refresh"), + id_token_encrypted=encryptor.encrypt("preferred-blocked-id"), + last_refresh=now, + status=AccountStatus.RATE_LIMITED, + deactivation_reason=None, + reset_at=reset_at, + ) + fallback = Account( + id="acc_preferred_fallback", + email="preferred-fallback@example.com", + plan_type="plus", + access_token_encrypted=encryptor.encrypt("fallback-access"), + refresh_token_encrypted=encryptor.encrypt("fallback-refresh"), + id_token_encrypted=encryptor.encrypt("fallback-id"), + last_refresh=now, + status=AccountStatus.ACTIVE, + deactivation_reason=None, + ) + + async with SessionLocal() as session: + accounts_repo = AccountsRepository(session) + usage_repo = UsageRepository(session) + await accounts_repo.upsert(preferred) + await accounts_repo.upsert(fallback) + await usage_repo.add_entry( + account_id=preferred.id, + used_percent=10.0, + window="primary", + reset_at=reset_at, + window_minutes=300, + ) + await usage_repo.add_entry( + account_id=fallback.id, + used_percent=40.0, + window="primary", + reset_at=reset_at, + window_minutes=300, + ) + + balancer = LoadBalancer(_repo_factory) + selection = await balancer.select_account(preferred_account_id=preferred.id) + + assert selection.account is not None + assert selection.account.id == fallback.id diff --git a/tests/integration/test_migrations.py b/tests/integration/test_migrations.py index 62598c52..4d1e1fe6 100644 --- a/tests/integration/test_migrations.py +++ b/tests/integration/test_migrations.py @@ -1,6 +1,7 @@ from __future__ import annotations import pytest +import sqlalchemy as sa from sqlalchemy import text from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine @@ -60,6 +61,30 @@ def _make_account(account_id: str, email: str, plan_type: str) -> Account: ) +@pytest.mark.asyncio +async def test_run_startup_migrations_creates_response_snapshots_table(db_setup): + result = await run_startup_migrations(_DATABASE_URL) + assert result.current_revision == _HEAD_REVISION + + async with SessionLocal() as session: + def _inspect_tables(sync_session): + inspector = sa.inspect(sync_session.connection()) + return inspector.get_columns("response_snapshots") + + columns = await session.run_sync(_inspect_tables) + + column_names = {str(column["name"]) for column in columns} + assert { + "response_id", + "parent_response_id", + "account_id", + "model", + "input_items_json", + "response_json", + "created_at", + }.issubset(column_names) + + @pytest.mark.asyncio async def test_run_startup_migrations_preserves_unknown_plan_types(db_setup): async with SessionLocal() as session: diff --git a/tests/integration/test_openai_compat_features.py b/tests/integration/test_openai_compat_features.py index f0c0c642..a2c1d1bb 100644 --- a/tests/integration/test_openai_compat_features.py +++ b/tests/integration/test_openai_compat_features.py @@ -95,20 +95,63 @@ async def test_v1_responses_rejects_input_file_id(async_client): @pytest.mark.asyncio -async def test_v1_responses_rejects_previous_response_id(async_client): +async def test_v1_responses_unknown_previous_response_id_errors(async_client): payload = { "model": "gpt-5.2", "previous_response_id": "resp_abc123", - "input": [ - { - "role": "user", - "content": [{"type": "input_text", "text": "Continue."}], - } - ], + "input": [{"role": "user", "content": [{"type": "input_text", "text": "Continue."}]}], } resp = await async_client.post("/v1/responses", json=payload) assert resp.status_code == 400 - assert resp.json()["error"]["type"] == "invalid_request_error" + error = resp.json()["error"] + assert error["type"] == "invalid_request_error" + assert error["param"] == "previous_response_id" + assert error["message"] == "Unknown previous_response_id" + + +@pytest.mark.asyncio +async def test_v1_responses_replays_previous_response_after_restart(async_client, app_instance, monkeypatch): + await _import_account(async_client, "acc_prev_response", "prev-response@example.com") + + seen_inputs: list[object] = [] + + async def fake_stream(payload, headers, access_token, account_id, base_url=None, raise_for_status=False, **_kwargs): + seen_inputs.append(payload.input) + if len(seen_inputs) == 1: + yield ( + 'data: {"type":"response.output_item.done","output_index":0,' + '"item":{"id":"msg_prev","type":"message","role":"assistant",' + '"content":[{"type":"output_text","text":"Prior answer"}]}}\n\n' + ) + yield _completed_event("resp_prev") + return + yield _completed_event("resp_followup") + + monkeypatch.setattr(proxy_module, "core_stream_responses", fake_stream) + + first = await async_client.post("/v1/responses", json={"model": "gpt-5.2", "input": "Hello"}) + assert first.status_code == 200 + + if hasattr(app_instance.state, "proxy_service"): + delattr(app_instance.state, "proxy_service") + + second = await async_client.post( + "/v1/responses", + json={ + "model": "gpt-5.2", + "previous_response_id": "resp_prev", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "Continue"}]}], + }, + ) + assert second.status_code == 200 + assert seen_inputs == [ + [{"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}], + [ + {"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}, + {"role": "assistant", "content": [{"type": "output_text", "text": "Prior answer"}]}, + {"role": "user", "content": [{"type": "input_text", "text": "Continue"}]}, + ], + ] @pytest.mark.asyncio diff --git a/tests/integration/test_proxy_websocket_responses.py b/tests/integration/test_proxy_websocket_responses.py index 0d7d2a19..d7e884ec 100644 --- a/tests/integration/test_proxy_websocket_responses.py +++ b/tests/integration/test_proxy_websocket_responses.py @@ -89,6 +89,172 @@ def _websocket_settings(**overrides): return SimpleNamespace(**values) +def test_v1_responses_websocket_replays_previous_response_after_restart(app_instance, monkeypatch): + first_upstream = _FakeUpstreamWebSocket( + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_prev", "object": "response", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.output_item.done", + "output_index": 0, + "item": { + "id": "msg_ws_prev", + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Prior answer"}], + }, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": "resp_ws_prev", + "object": "response", + "status": "completed", + "usage": {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}, + }, + }, + separators=(",", ":"), + ), + ), + ] + ) + second_upstream = _FakeUpstreamWebSocket( + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_next", "object": "response", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": "resp_ws_next", + "object": "response", + "status": "completed", + "usage": {"input_tokens": 2, "output_tokens": 1, "total_tokens": 3}, + }, + }, + separators=(",", ":"), + ), + ), + ] + ) + upstreams = deque([first_upstream, second_upstream]) + connect_calls: list[dict[str, object]] = [] + + class _FakeSettingsCache: + async def get(self): + return _websocket_settings() + + async def allow_firewall(_websocket): + return None + + async def allow_proxy_api_key(_authorization: str | None): + return None + + async def fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + preferred_account_id=None, + ): + del self, headers, sticky_key, sticky_kind, reallocate_sticky, sticky_max_age_seconds + del prefer_earlier_reset, routing_strategy, model, api_key, client_send_lock, websocket + connect_calls.append( + { + "request_id": request_state.request_id, + "preferred_account_id": preferred_account_id, + } + ) + return SimpleNamespace(id="acct_ws_prev"), upstreams.popleft() + + async def fake_write_request_log(self, **kwargs): + del self, kwargs + + monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall) + monkeypatch.setattr(proxy_api_module, "validate_proxy_api_key_authorization", allow_proxy_api_key) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _FakeSettingsCache()) + monkeypatch.setattr(proxy_module.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + monkeypatch.setattr(proxy_module.ProxyService, "_write_request_log", fake_write_request_log) + + with TestClient(app_instance) as client: + with client.websocket_connect("/v1/responses") as websocket: + websocket.send_text(json.dumps({"type": "response.create", "model": "gpt-5.2", "input": "Hello"})) + assert json.loads(websocket.receive_text())["type"] == "response.created" + assert json.loads(websocket.receive_text())["type"] == "response.output_item.done" + assert json.loads(websocket.receive_text())["type"] == "response.completed" + + if hasattr(app_instance.state, "proxy_service"): + delattr(app_instance.state, "proxy_service") + + with client.websocket_connect("/v1/responses") as websocket: + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.2", + "previous_response_id": "resp_ws_prev", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "Continue"}]}], + } + ) + ) + assert json.loads(websocket.receive_text())["type"] == "response.created" + assert json.loads(websocket.receive_text())["type"] == "response.completed" + + assert connect_calls[1]["preferred_account_id"] == "acct_ws_prev" + assert [json.loads(message) for message in second_upstream.sent_text] == [ + { + "type": "response.create", + "model": "gpt-5.2", + "instructions": "", + "input": [ + {"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}, + {"role": "assistant", "content": [{"type": "output_text", "text": "Prior answer"}]}, + {"role": "user", "content": [{"type": "input_text", "text": "Continue"}]}, + ], + "tools": [], + "store": False, + "include": [], + } + ] + + def test_backend_responses_websocket_proxies_upstream_and_persists_log(app_instance, monkeypatch): upstream_messages = [ _FakeUpstreamMessage( diff --git a/tests/unit/test_openai_requests.py b/tests/unit/test_openai_requests.py index a12cc5eb..d8ebbce9 100644 --- a/tests/unit/test_openai_requests.py +++ b/tests/unit/test_openai_requests.py @@ -362,7 +362,18 @@ def test_responses_accepts_known_include_values(): assert request.include == ["reasoning.encrypted_content", "web_search_call.action.sources"] -def test_responses_rejects_conversation_previous_response_id(): +def test_responses_accepts_previous_response_id_without_conversation(): + payload = { + "model": "gpt-5.1", + "instructions": "hi", + "input": [], + "previous_response_id": "resp_1", + } + request = ResponsesRequest.model_validate(payload) + assert request.previous_response_id == "resp_1" + + +def test_responses_rejects_conversation_and_previous_response_id_together(): payload = { "model": "gpt-5.1", "instructions": "hi", @@ -370,7 +381,7 @@ def test_responses_rejects_conversation_previous_response_id(): "conversation": "conv_1", "previous_response_id": "resp_1", } - with pytest.raises(ValueError, match="previous_response_id is not supported"): + with pytest.raises(ValueError, match="Provide either 'conversation' or 'previous_response_id', not both"): ResponsesRequest.model_validate(payload) diff --git a/tests/unit/test_proxy_utils.py b/tests/unit/test_proxy_utils.py index bbdab539..8a0761a4 100644 --- a/tests/unit/test_proxy_utils.py +++ b/tests/unit/test_proxy_utils.py @@ -297,9 +297,39 @@ async def add_log(self, **kwargs: object) -> None: self.calls.append(dict(kwargs)) +class _ResponseSnapshotsStub: + def __init__(self, snapshots: dict[str, object] | None = None) -> None: + self._snapshots = snapshots or {} + self.upserts: list[dict[str, object]] = [] + + async def get(self, response_id: str): + return self._snapshots.get(response_id) + + async def upsert(self, **kwargs: object) -> object: + self.upserts.append(dict(kwargs)) + snapshot = SimpleNamespace( + response_id=kwargs["response_id"], + parent_response_id=kwargs.get("parent_response_id"), + account_id=kwargs.get("account_id"), + model=kwargs["model"], + input_items_json=json.dumps(kwargs["input_items"], separators=(",", ":")), + response_json=json.dumps(kwargs["response_payload"], separators=(",", ":")), + ) + self._snapshots[str(kwargs["response_id"])] = snapshot + return snapshot + + class _RepoContext: - def __init__(self, request_logs: _RequestLogsRecorder) -> None: - self._repos = SimpleNamespace(request_logs=request_logs) + def __init__( + self, + request_logs: _RequestLogsRecorder, + *, + response_snapshots: _ResponseSnapshotsStub | None = None, + ) -> None: + self._repos = SimpleNamespace( + request_logs=request_logs, + response_snapshots=response_snapshots, + ) async def __aenter__(self) -> object: return self._repos @@ -308,9 +338,13 @@ async def __aexit__(self, exc_type, exc, tb) -> bool: return False -def _repo_factory(request_logs: _RequestLogsRecorder): +def _repo_factory( + request_logs: _RequestLogsRecorder, + *, + response_snapshots: _ResponseSnapshotsStub | None = None, +): def factory() -> _RepoContext: - return _RepoContext(request_logs) + return _RepoContext(request_logs, response_snapshots=response_snapshots) return factory @@ -2790,6 +2824,68 @@ async def test_connect_proxy_websocket_maps_handshake_budget_exhaustion_to_timeo assert request_logs.calls[0]["error_code"] == "upstream_request_timeout" +@pytest.mark.asyncio +async def test_prepare_websocket_response_create_request_rebuilds_previous_response_history(monkeypatch): + request_logs = _RequestLogsRecorder() + response_snapshots = _ResponseSnapshotsStub( + { + "resp_prev": SimpleNamespace( + response_id="resp_prev", + parent_response_id=None, + account_id="acc_prev", + input_items_json=json.dumps( + [{"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}], + separators=(",", ":"), + ), + response_json=json.dumps( + { + "id": "resp_prev", + "output": [ + { + "id": "msg_prev", + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Prior answer"}], + } + ], + }, + separators=(",", ":"), + ), + ) + } + ) + service = proxy_service.ProxyService(_repo_factory(request_logs, response_snapshots=response_snapshots)) + + monkeypatch.setattr(service, "_reserve_websocket_api_key_usage", AsyncMock(return_value=None)) + monkeypatch.setattr(service, "_refresh_websocket_api_key_policy", AsyncMock(return_value=None)) + + prepared = await service._prepare_websocket_response_create_request( + { + "type": "response.create", + "model": "gpt-5.2", + "previous_response_id": "resp_prev", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "Continue"}]}], + }, + headers={}, + codex_session_affinity=False, + openai_cache_affinity=True, + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=300, + api_key=None, + ) + + assert prepared.request_state.parent_response_id == "resp_prev" + assert prepared.request_state.preferred_account_id == "acc_prev" + assert prepared.request_state.current_input_items == [ + {"role": "user", "content": [{"type": "input_text", "text": "Continue"}]} + ] + assert json.loads(prepared.text_data)["input"] == [ + {"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}, + {"role": "assistant", "content": [{"type": "output_text", "text": "Prior answer"}]}, + {"role": "user", "content": [{"type": "input_text", "text": "Continue"}]}, + ] + + @pytest.mark.asyncio async def test_prepare_websocket_response_create_request_normalizes_payload_and_reserves_forwarded_tier(monkeypatch): request_logs = _RequestLogsRecorder() diff --git a/uv.lock b/uv.lock index 8ec940dc..1cc74687 100644 --- a/uv.lock +++ b/uv.lock @@ -368,7 +368,7 @@ wheels = [ [[package]] name = "codex-lb" -version = "1.4.1" +version = "1.5.3" source = { editable = "." } dependencies = [ { name = "aiohttp" }, From eb8ac27e719c51ba49dadca85e733cc03f0a6828 Mon Sep 17 00:00:00 2001 From: xirothedev Date: Sun, 15 Mar 2026 19:06:49 +0700 Subject: [PATCH 2/7] feat(db): switch runtime to neon postgresql Require PostgreSQL runtime configuration, add a dedicated migration DSN, and remove SQLite-specific startup behavior from the main runtime path. Normalize asyncpg SSL query parameters for Docker deployments and update OpenSpec, docs, compose, and tests for the Neon-first database contract. --- .env.example | 13 +- README.md | 12 +- app/core/config/settings.py | 36 +-- app/db/alembic/env.py | 4 +- app/db/migrate.py | 5 +- app/db/session.py | 185 +++------------ docker-compose.yml | 20 -- .../neon-postgresql-runtime/proposal.md | 15 ++ .../specs/database-backends/spec.md | 31 +++ .../specs/database-migrations/spec.md | 16 ++ .../changes/neon-postgresql-runtime/tasks.md | 4 + openspec/specs/database-backends/context.md | 32 ++- openspec/specs/database-backends/spec.md | 63 ++--- openspec/specs/database-migrations/context.md | 4 +- openspec/specs/database-migrations/spec.md | 20 +- openspec/specs/query-caching/context.md | 2 +- tests/conftest.py | 17 +- tests/unit/test_db_session.py | 215 ++++++------------ 18 files changed, 264 insertions(+), 430 deletions(-) create mode 100644 openspec/changes/neon-postgresql-runtime/proposal.md create mode 100644 openspec/changes/neon-postgresql-runtime/specs/database-backends/spec.md create mode 100644 openspec/changes/neon-postgresql-runtime/specs/database-migrations/spec.md create mode 100644 openspec/changes/neon-postgresql-runtime/tasks.md diff --git a/.env.example b/.env.example index d2781fd6..dcf437d6 100644 --- a/.env.example +++ b/.env.example @@ -1,10 +1,7 @@ -# Database -CODEX_LB_DATABASE_URL=sqlite+aiosqlite:///~/.codex-lb/store.db -# Optional PostgreSQL example (SQLite stays default if not set): -# CODEX_LB_DATABASE_URL=postgresql+asyncpg://codex_lb:codex_lb@127.0.0.1:5432/codex_lb +# Database (Neon PostgreSQL) +CODEX_LB_DATABASE_URL=postgresql+asyncpg://USER:PASSWORD@EP-POOLER.REGION.aws.neon.tech/codex_lb?sslmode=require +CODEX_LB_DATABASE_MIGRATION_URL=postgresql+asyncpg://USER:PASSWORD@EP-DIRECT.REGION.aws.neon.tech/codex_lb?sslmode=require CODEX_LB_DATABASE_MIGRATE_ON_STARTUP=true -CODEX_LB_DATABASE_SQLITE_PRE_MIGRATE_BACKUP_ENABLED=true -CODEX_LB_DATABASE_SQLITE_PRE_MIGRATE_BACKUP_MAX_FILES=5 # Upstream ChatGPT base URL (no /codex suffix) CODEX_LB_UPSTREAM_BASE_URL=https://chatgpt.com/backend-api @@ -26,8 +23,8 @@ CODEX_LB_OAUTH_CALLBACK_PORT=1455 CODEX_LB_TOKEN_REFRESH_TIMEOUT_SECONDS=30 CODEX_LB_TOKEN_REFRESH_INTERVAL_DAYS=8 -# Encryption key file (optional override; recommended for Docker volumes) -# CODEX_LB_ENCRYPTION_KEY_FILE=/var/lib/codex-lb/encryption.key +# Encryption key file (recommended for Docker volumes) +CODEX_LB_ENCRYPTION_KEY_FILE=/var/lib/codex-lb/encryption.key # Upstream usage fetch CODEX_LB_USAGE_FETCH_TIMEOUT_SECONDS=10 diff --git a/README.md b/README.md index 3bbe081e..5f1ada55 100644 --- a/README.md +++ b/README.md @@ -51,9 +51,13 @@ docker volume create codex-lb-data docker run -d --name codex-lb \ -p 2455:2455 -p 1455:1455 \ -v codex-lb-data:/var/lib/codex-lb \ + -e CODEX_LB_DATABASE_URL=postgresql+asyncpg://USER:PASSWORD@EP-POOLER.REGION.aws.neon.tech/codex_lb?sslmode=require \ + -e CODEX_LB_DATABASE_MIGRATION_URL=postgresql+asyncpg://USER:PASSWORD@EP-DIRECT.REGION.aws.neon.tech/codex_lb?sslmode=require \ ghcr.io/soju06/codex-lb:latest # or uvx +CODEX_LB_DATABASE_URL=postgresql+asyncpg://USER:PASSWORD@EP-POOLER.REGION.aws.neon.tech/codex_lb?sslmode=require \ +CODEX_LB_DATABASE_MIGRATION_URL=postgresql+asyncpg://USER:PASSWORD@EP-DIRECT.REGION.aws.neon.tech/codex_lb?sslmode=require \ uvx codex-lb ``` @@ -289,16 +293,16 @@ Authorization: Bearer sk-clb-... Environment variables with `CODEX_LB_` prefix or `.env.local`. See [`.env.example`](.env.example). Dashboard auth is configured in Settings. -SQLite is the default database backend; PostgreSQL is optional via `CODEX_LB_DATABASE_URL` (for example `postgresql+asyncpg://...`). +Neon PostgreSQL is required for runtime persistence. Set `CODEX_LB_DATABASE_URL` to the pooled Neon DSN and `CODEX_LB_DATABASE_MIGRATION_URL` to the direct Neon DSN used by Alembic/startup migrations. ## Data -| Environment | Path | -|-------------|------| +| Environment | Local files | +|-------------|-------------| | Local / uvx | `~/.codex-lb/` | | Docker | `/var/lib/codex-lb/` | -Backup this directory to preserve your data. +These local paths primarily store the encryption key and other local runtime files. Application data lives in Neon PostgreSQL, so backup and recovery must include your Neon database, not just the local directory. ## Development diff --git a/app/core/config/settings.py b/app/core/config/settings.py index bfdced5d..a8385ca3 100644 --- a/app/core/config/settings.py +++ b/app/core/config/settings.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Annotated, Literal -from pydantic import Field, field_validator +from pydantic import Field, field_validator, model_validator from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict BASE_DIR = Path(__file__).resolve().parents[3] @@ -31,7 +31,6 @@ def _default_oauth_callback_host() -> str: DEFAULT_HOME_DIR = _default_home_dir() -DEFAULT_DB_PATH = DEFAULT_HOME_DIR / "store.db" DEFAULT_ENCRYPTION_KEY_FILE = DEFAULT_HOME_DIR / "encryption.key" @@ -43,14 +42,12 @@ class Settings(BaseSettings): extra="ignore", ) - database_url: str = f"sqlite+aiosqlite:///{DEFAULT_DB_PATH}" + database_url: str = Field(min_length=1) + database_migration_url: str | None = None database_pool_size: int = Field(default=15, gt=0) database_max_overflow: int = Field(default=10, ge=0) database_pool_timeout_seconds: float = Field(default=30.0, gt=0) database_migrate_on_startup: bool = True - database_sqlite_pre_migrate_backup_enabled: bool = True - database_sqlite_pre_migrate_backup_max_files: int = Field(default=5, ge=1) - database_sqlite_startup_check_mode: Literal["quick", "full", "off"] = "quick" database_alembic_auto_remap_enabled: bool = True upstream_base_url: str = "https://chatgpt.com/backend-api" upstream_stream_transport: Literal["http", "websocket", "auto"] = "auto" @@ -97,16 +94,22 @@ class Settings(BaseSettings): default_factory=lambda: ["127.0.0.1/32", "::1/128"] ) - @field_validator("database_url") + @field_validator("database_url", "database_migration_url") @classmethod - def _expand_database_url(cls, value: str) -> str: - for prefix in ("sqlite+aiosqlite:///", "sqlite:///"): - if value.startswith(prefix): - path = value[len(prefix) :] - if path.startswith("~"): - return f"{prefix}{Path(path).expanduser()}" + def _normalize_database_url(cls, value: str | None) -> str | None: + if value is None: + return None + value = value.strip() + if not value: + raise ValueError("database url must not be empty") return value + @model_validator(mode="after") + def _finalize_database_urls(self) -> Settings: + if self.database_migration_url is None: + self.database_migration_url = self.database_url + return self + @field_validator("encryption_key_file", mode="before") @classmethod def _expand_encryption_key_file(cls, value: str | Path) -> Path: @@ -171,4 +174,9 @@ def _validate_upstream_compact_timeout_seconds(cls, value: float | None) -> floa @lru_cache(maxsize=1) def get_settings() -> Settings: - return Settings() + settings = Settings() + if settings.database_migrate_on_startup and not settings.database_migration_url: + raise RuntimeError( + "CODEX_LB_DATABASE_MIGRATION_URL is required when database migrations on startup are enabled" + ) + return settings diff --git a/app/db/alembic/env.py b/app/db/alembic/env.py index ab5d485c..5b8d1406 100644 --- a/app/db/alembic/env.py +++ b/app/db/alembic/env.py @@ -21,7 +21,7 @@ def _sync_database_url() -> str: configured = config.get_main_option("sqlalchemy.url") if configured: return configured - return to_sync_database_url(get_settings().database_url) + return to_sync_database_url(get_settings().database_migration_url or get_settings().database_url) def run_migrations_offline() -> None: @@ -34,7 +34,6 @@ def run_migrations_offline() -> None: literal_binds=True, dialect_opts={"paramstyle": "named"}, compare_type=True, - render_as_batch=url.startswith("sqlite"), ) with context.begin_transaction(): @@ -57,7 +56,6 @@ def run_migrations_online() -> None: connection=connection, target_metadata=target_metadata, compare_type=True, - render_as_batch=connection.dialect.name == "sqlite", ) with context.begin_transaction(): diff --git a/app/db/migrate.py b/app/db/migrate.py index 6cb5c795..489b934c 100644 --- a/app/db/migrate.py +++ b/app/db/migrate.py @@ -564,7 +564,7 @@ def _parse_args() -> argparse.Namespace: parser.add_argument( "--db-url", default=None, - help="Database URL to migrate. Defaults to CODEX_LB_DATABASE_URL from settings.", + help="Database URL to migrate. Defaults to CODEX_LB_DATABASE_MIGRATION_URL or CODEX_LB_DATABASE_URL from settings.", ) subparsers = parser.add_subparsers(dest="command", required=True) @@ -594,7 +594,8 @@ def _parse_args() -> argparse.Namespace: def main() -> None: args = _parse_args() - database_url = args.db_url or get_settings().database_url + settings = get_settings() + database_url = args.db_url or settings.database_migration_url or settings.database_url if args.command == "upgrade": result = run_upgrade( diff --git a/app/db/session.py b/app/db/session.py index c1f15e31..db624327 100644 --- a/app/db/session.py +++ b/app/db/session.py @@ -1,116 +1,51 @@ from __future__ import annotations import logging -import sqlite3 from contextlib import asynccontextmanager -from pathlib import Path -from typing import TYPE_CHECKING, AsyncIterator, Awaitable, Callable, Protocol, TypeVar +from typing import TYPE_CHECKING, AsyncIterator, Awaitable, Callable, TypeVar import anyio from anyio import to_thread -from sqlalchemy import event -from sqlalchemy.engine import Engine +from sqlalchemy.engine import make_url from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from app.core.config.settings import get_settings -from app.db.sqlite_utils import SqliteIntegrityCheckMode, check_sqlite_integrity, sqlite_db_path_from_url if TYPE_CHECKING: - from app.db.migrate import MigrationRunResult, MigrationState + from app.db.migrate import MigrationRunResult _settings = get_settings() logger = logging.getLogger(__name__) -_SQLITE_BUSY_TIMEOUT_MS = 5_000 -_SQLITE_BUSY_TIMEOUT_SECONDS = _SQLITE_BUSY_TIMEOUT_MS / 1000 +def _runtime_database_url(database_url: str) -> str: + parsed = make_url(database_url) + if parsed.drivername != "postgresql+asyncpg": + return database_url -def _is_sqlite_url(url: str) -> bool: - return url.startswith("sqlite+aiosqlite:///") or url.startswith("sqlite:///") + query = dict(parsed.query) + sslmode = query.pop("sslmode", None) + if sslmode is not None and "ssl" not in query: + query["ssl"] = sslmode + query.pop("channel_binding", None) + normalized = parsed.set(query=query) + return normalized.render_as_string(hide_password=False) -def _is_sqlite_memory_url(url: str) -> bool: - return _is_sqlite_url(url) and ":memory:" in url - - -def _configure_sqlite_engine(engine: Engine, *, enable_wal: bool) -> None: - @event.listens_for(engine, "connect") - def _set_sqlite_pragmas(dbapi_connection: sqlite3.Connection, _: object) -> None: - cursor: sqlite3.Cursor = dbapi_connection.cursor() - try: - if enable_wal: - cursor.execute("PRAGMA journal_mode=WAL") - cursor.execute("PRAGMA synchronous=NORMAL") - cursor.execute("PRAGMA foreign_keys=ON") - cursor.execute(f"PRAGMA busy_timeout={_SQLITE_BUSY_TIMEOUT_MS}") - finally: - cursor.close() - - -if _is_sqlite_url(_settings.database_url): - is_sqlite_memory = _is_sqlite_memory_url(_settings.database_url) - if is_sqlite_memory: - engine = create_async_engine( - _settings.database_url, - echo=False, - connect_args={"timeout": _SQLITE_BUSY_TIMEOUT_SECONDS}, - ) - else: - engine = create_async_engine( - _settings.database_url, - echo=False, - pool_size=_settings.database_pool_size, - max_overflow=_settings.database_max_overflow, - pool_timeout=_settings.database_pool_timeout_seconds, - connect_args={"timeout": _SQLITE_BUSY_TIMEOUT_SECONDS}, - ) - _configure_sqlite_engine(engine.sync_engine, enable_wal=not is_sqlite_memory) -else: - engine = create_async_engine( - _settings.database_url, - echo=False, - pool_size=_settings.database_pool_size, - max_overflow=_settings.database_max_overflow, - pool_timeout=_settings.database_pool_timeout_seconds, - ) +engine = create_async_engine( + _runtime_database_url(_settings.database_url), + echo=False, + pool_size=_settings.database_pool_size, + max_overflow=_settings.database_max_overflow, + pool_timeout=_settings.database_pool_timeout_seconds, +) SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) _T = TypeVar("_T") -class _SqliteBackupCreator(Protocol): - def __call__(self, source: Path, *, max_files: int) -> Path: ... - - -def _ensure_sqlite_dir(url: str) -> None: - if not (url.startswith("sqlite+aiosqlite:") or url.startswith("sqlite:")): - return - - marker = ":///" - marker_index = url.find(marker) - if marker_index < 0: - return - - # Works for both relative (sqlite+aiosqlite:///./db.sqlite) and absolute - # paths (sqlite+aiosqlite:////var/lib/app/db.sqlite). - path = url[marker_index + len(marker) :] - path = path.partition("?")[0] - path = path.partition("#")[0] - - if not path or path == ":memory:": - return - - Path(path).expanduser().parent.mkdir(parents=True, exist_ok=True) - - -def _startup_sqlite_check_mode(raw_mode: str) -> SqliteIntegrityCheckMode | None: - if raw_mode == "off": - return None - return SqliteIntegrityCheckMode(raw_mode) - - async def _shielded(awaitable: Awaitable[_T]) -> _T: with anyio.CancelScope(shield=True): return await awaitable @@ -133,19 +68,12 @@ async def _safe_close(session: AsyncSession) -> None: def _load_migration_entrypoints() -> tuple[ - Callable[[str], "MigrationState"], Callable[[str], Awaitable["MigrationRunResult"]], Callable[[str], tuple[str, ...]], ]: - from app.db.migrate import check_schema_drift, inspect_migration_state, run_startup_migrations + from app.db.migrate import check_schema_drift, run_startup_migrations - return inspect_migration_state, run_startup_migrations, check_schema_drift - - -def _load_sqlite_backup_creator() -> _SqliteBackupCreator: - from app.db.backup import create_sqlite_pre_migration_backup - - return create_sqlite_pre_migration_backup + return run_startup_migrations, check_schema_drift @asynccontextmanager @@ -177,42 +105,18 @@ async def get_session() -> AsyncIterator[AsyncSession]: async def init_db() -> None: - _ensure_sqlite_dir(_settings.database_url) - sqlite_path = sqlite_db_path_from_url(_settings.database_url) - if sqlite_path is not None: - check_mode = _startup_sqlite_check_mode(_settings.database_sqlite_startup_check_mode) - if check_mode is not None: - integrity = check_sqlite_integrity(sqlite_path, mode=check_mode) - if not integrity.ok: - details = integrity.details or "unknown error" - pragma_name = "quick_check" if check_mode == SqliteIntegrityCheckMode.QUICK else "integrity_check" - logger.error( - "SQLite %s failed path=%s details=%s", - pragma_name, - sqlite_path, - details, - ) - if "locked" in details.lower(): - message = ( - f"SQLite {pragma_name} failed for {sqlite_path} ({details}). " - "Another instance may be running. Stop it and retry." - ) - else: - message = ( - f"SQLite {pragma_name} failed for {sqlite_path} ({details}). " - "The database appears corrupted or the filesystem is unhealthy. " - "Stop the app and run " - f'`python -m app.db.recover --db "{sqlite_path}" --replace` ' - "or restore a backup from the same directory." - ) - raise RuntimeError(message) - if not _settings.database_migrate_on_startup: logger.info("Startup database migration is disabled") return + migration_url = _settings.database_migration_url + if not migration_url: + raise RuntimeError( + "CODEX_LB_DATABASE_MIGRATION_URL is required when database migrations on startup are enabled" + ) + try: - inspect_migration_state, run_startup_migrations, check_schema_drift = _load_migration_entrypoints() + run_startup_migrations, check_schema_drift = _load_migration_entrypoints() except ModuleNotFoundError as exc: if exc.name != "app.db.migrate": raise @@ -222,33 +126,8 @@ async def init_db() -> None: logger.exception("Failed to import database migration entrypoints from app.db.migrate") raise RuntimeError("Database migration entrypoint app.db.migrate is invalid") from exc - if sqlite_path is not None and _settings.database_sqlite_pre_migrate_backup_enabled and sqlite_path.exists(): - migration_state = await to_thread.run_sync( - lambda: inspect_migration_state(_settings.database_url), - ) - if migration_state.needs_upgrade: - try: - create_sqlite_pre_migration_backup = _load_sqlite_backup_creator() - except ModuleNotFoundError as exc: - if exc.name != "app.db.backup": - raise - logger.exception("Failed to import SQLite backup module=app.db.backup") - raise RuntimeError("SQLite backup module app.db.backup is unavailable") from exc - - backup_path = await to_thread.run_sync( - lambda: create_sqlite_pre_migration_backup( - sqlite_path, - max_files=_settings.database_sqlite_pre_migrate_backup_max_files, - ), - ) - logger.info( - "Created SQLite pre-migration backup path=%s target_revision=%s", - backup_path, - migration_state.head_revision, - ) - try: - result = await run_startup_migrations(_settings.database_url) + result = await run_startup_migrations(migration_url) if result.bootstrap.stamped_revision is not None: logger.info( "Bootstrapped legacy migrations stamped_revision=%s legacy_rows=%s", @@ -257,7 +136,7 @@ async def init_db() -> None: ) if result.current_revision is not None: logger.info("Database migration complete revision=%s", result.current_revision) - drift = await to_thread.run_sync(lambda: check_schema_drift(_settings.database_url)) + drift = await to_thread.run_sync(lambda: check_schema_drift(migration_url)) if drift: drift_details = "; ".join(drift) raise RuntimeError(f"Schema drift detected after startup migrations: {drift_details}") diff --git a/docker-compose.yml b/docker-compose.yml index ae5c6a36..c454d42c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -58,26 +58,6 @@ services: - action: rebuild path: ./frontend/bun.lock - postgres: - image: postgres:16-alpine - profiles: ["postgres"] - environment: - POSTGRES_USER: codex_lb - POSTGRES_PASSWORD: codex_lb - POSTGRES_DB: codex_lb - ports: - - "5432:5432" - healthcheck: - test: ["CMD-SHELL", "pg_isready -U codex_lb -d codex_lb"] - interval: 10s - timeout: 5s - retries: 5 - volumes: - - codex-lb-postgres-data:/var/lib/postgresql/data - restart: unless-stopped - volumes: codex-lb-data: name: codex-lb-data - codex-lb-postgres-data: - name: codex-lb-postgres-data diff --git a/openspec/changes/neon-postgresql-runtime/proposal.md b/openspec/changes/neon-postgresql-runtime/proposal.md new file mode 100644 index 00000000..741a1d73 --- /dev/null +++ b/openspec/changes/neon-postgresql-runtime/proposal.md @@ -0,0 +1,15 @@ +# Neon PostgreSQL Runtime + +## Why +The repo still defaults to SQLite in runtime, tests, and docs, which conflicts with the intended deployment model of using Neon as the primary remote database. + +## What Changes +- Make PostgreSQL on Neon the required runtime backend. +- Add a dedicated migration URL setting for startup migrations and Alembic. +- Remove SQLite-specific startup behavior from the runtime path. +- Update Docker, env templates, and tests to stop assuming SQLite defaults. + +## Impact +- Runtime now fails fast if Neon database URLs are missing. +- Docker Compose no longer provisions local PostgreSQL. +- Existing SQLite files are not migrated; PostgreSQL starts fresh. diff --git a/openspec/changes/neon-postgresql-runtime/specs/database-backends/spec.md b/openspec/changes/neon-postgresql-runtime/specs/database-backends/spec.md new file mode 100644 index 00000000..ccf7b1d0 --- /dev/null +++ b/openspec/changes/neon-postgresql-runtime/specs/database-backends/spec.md @@ -0,0 +1,31 @@ +## MODIFIED Requirements + +### Requirement: PostgreSQL on Neon is the required runtime backend +The service MUST require `CODEX_LB_DATABASE_URL` to be set to a PostgreSQL SQLAlchemy async DSN and MUST fail fast when it is missing. + +#### Scenario: Runtime starts without a database URL +- **WHEN** the service starts without `CODEX_LB_DATABASE_URL` +- **THEN** settings initialization fails with an explicit configuration error + +#### Scenario: Runtime starts with a PostgreSQL URL +- **WHEN** `CODEX_LB_DATABASE_URL` is set to `postgresql+asyncpg://...` +- **THEN** service startup uses PostgreSQL for ORM operations +- **AND** it does not require SQLite path handling or startup validation + +### Requirement: Runtime migrations use a dedicated migration URL +The service MUST accept `CODEX_LB_DATABASE_MIGRATION_URL` as the canonical DSN for Alembic and startup migrations. + +#### Scenario: Migration URL is not set explicitly +- **WHEN** `CODEX_LB_DATABASE_MIGRATION_URL` is unset +- **THEN** the service uses `CODEX_LB_DATABASE_URL` as the migration DSN + +#### Scenario: Startup migrations are enabled without any migration DSN +- **WHEN** startup migrations are enabled and no resolved migration DSN is available +- **THEN** startup fails fast with an explicit configuration error + +### Requirement: Test suite requires explicit PostgreSQL configuration +The test bootstrap MUST honor `CODEX_LB_TEST_DATABASE_URL` as the runtime DSN and `CODEX_LB_TEST_DATABASE_MIGRATION_URL` as the migration DSN without defaulting to SQLite. + +#### Scenario: Tests start without PostgreSQL env +- **WHEN** tests are run without `CODEX_LB_TEST_DATABASE_URL` +- **THEN** test bootstrap fails with an explicit configuration error diff --git a/openspec/changes/neon-postgresql-runtime/specs/database-migrations/spec.md b/openspec/changes/neon-postgresql-runtime/specs/database-migrations/spec.md new file mode 100644 index 00000000..526878f6 --- /dev/null +++ b/openspec/changes/neon-postgresql-runtime/specs/database-migrations/spec.md @@ -0,0 +1,16 @@ +## MODIFIED Requirements + +### Requirement: Alembic startup uses the dedicated migration DSN +The system SHALL run startup migrations, revision inspection, and schema drift checks against the resolved migration DSN rather than the runtime pooled DSN when they differ. + +#### Scenario: Dedicated migration DSN is configured +- **WHEN** `CODEX_LB_DATABASE_MIGRATION_URL` is set +- **THEN** startup migrations, Alembic CLI defaults, and drift checks use that DSN +- **AND** runtime ORM sessions continue using `CODEX_LB_DATABASE_URL` + +### Requirement: SQLite runtime backup flow is not part of PostgreSQL startup +The system SHALL NOT perform SQLite integrity checks or pre-migration backups in the PostgreSQL runtime startup path. + +#### Scenario: Runtime starts on Neon PostgreSQL +- **WHEN** the configured backend is PostgreSQL +- **THEN** startup migration flow skips SQLite-specific backup and integrity tooling diff --git a/openspec/changes/neon-postgresql-runtime/tasks.md b/openspec/changes/neon-postgresql-runtime/tasks.md new file mode 100644 index 00000000..49adc9b0 --- /dev/null +++ b/openspec/changes/neon-postgresql-runtime/tasks.md @@ -0,0 +1,4 @@ +- [x] Update database backend specs and migration specs for Neon-first runtime behavior. +- [x] Change runtime settings and DB session startup to require PostgreSQL URLs and a dedicated migration URL. +- [x] Update Alembic environment, compose, and env templates to use Neon PostgreSQL only. +- [x] Update tests and docs to stop assuming SQLite defaults in runtime bootstrap. diff --git a/openspec/specs/database-backends/context.md b/openspec/specs/database-backends/context.md index be3bf6ca..ecdf6158 100644 --- a/openspec/specs/database-backends/context.md +++ b/openspec/specs/database-backends/context.md @@ -1,33 +1,27 @@ ## Overview -codex-lb is designed to be SQLite-first for simple local usage and container defaults. SQLite-specific resilience behavior (integrity checks, WAL tuning, recovery tooling) remains valuable for the default mode. - -For higher concurrency or infrastructure-managed deployments, PostgreSQL support is enabled through SQLAlchemy async URLs using `asyncpg`. +codex-lb runs against PostgreSQL on Neon for all supported runtime deployments. The runtime contract is split between a pooled application DSN for normal ORM traffic and an optional direct DSN for Alembic and startup migration operations. ## Decisions -- Keep SQLite as default to preserve zero-config startup. -- Accept PostgreSQL through `CODEX_LB_DATABASE_URL` only; no new configuration key aliases. -- Keep SQLite-specific recovery tooling SQLite-only; PostgreSQL operations should use PostgreSQL-native backup/recovery practices. -- Default SQLite startup validation to `quick` so normal boots stay fast while operators can still opt into `full` or `off`. +- Neon PostgreSQL is the required runtime backend. +- `CODEX_LB_DATABASE_URL` is the canonical pooled runtime DSN. +- `CODEX_LB_DATABASE_MIGRATION_URL` is the canonical migration DSN and falls back to `CODEX_LB_DATABASE_URL` when omitted. +- SQLite-specific recovery and validation tooling is no longer part of the runtime startup path. ## Operational Notes -- SQLite default URL: `sqlite+aiosqlite:///~/.codex-lb/store.db` -- SQLite startup check mode: `CODEX_LB_DATABASE_SQLITE_STARTUP_CHECK_MODE=quick|full|off` (default `quick`) -- PostgreSQL example URL: `postgresql+asyncpg://codex_lb:codex_lb@127.0.0.1:5432/codex_lb` -- Pool controls (`database_pool_size`, `database_max_overflow`, `database_pool_timeout_seconds`) apply to non-memory SQLite and PostgreSQL engine creation. +- Runtime DSN example: `postgresql+asyncpg://USER:PASSWORD@ep-pooler.region.aws.neon.tech/codex_lb?sslmode=require` +- Migration DSN example: `postgresql+asyncpg://USER:PASSWORD@ep-direct.region.aws.neon.tech/codex_lb?sslmode=require` +- Pool controls (`database_pool_size`, `database_max_overflow`, `database_pool_timeout_seconds`) apply to the runtime async engine. +- Tests and CI should use a dedicated Neon database or branch via `CODEX_LB_TEST_DATABASE_URL` and `CODEX_LB_TEST_DATABASE_MIGRATION_URL`. ## Example -Use PostgreSQL while keeping all other defaults: - -```bash -CODEX_LB_DATABASE_URL=postgresql+asyncpg://codex_lb:codex_lb@127.0.0.1:5432/codex_lb codex-lb -``` - -Use SQLite with explicit full startup validation: +Use Neon with a pooled runtime connection and a direct migration connection: ```bash -CODEX_LB_DATABASE_SQLITE_STARTUP_CHECK_MODE=full codex-lb +CODEX_LB_DATABASE_URL=postgresql+asyncpg://USER:PASSWORD@ep-pooler.region.aws.neon.tech/codex_lb?sslmode=require \ +CODEX_LB_DATABASE_MIGRATION_URL=postgresql+asyncpg://USER:PASSWORD@ep-direct.region.aws.neon.tech/codex_lb?sslmode=require \ +codex-lb ``` diff --git a/openspec/specs/database-backends/spec.md b/openspec/specs/database-backends/spec.md index b8dd955c..c31daed1 100644 --- a/openspec/specs/database-backends/spec.md +++ b/openspec/specs/database-backends/spec.md @@ -2,63 +2,44 @@ ## Purpose -Define supported database backends and default backend behavior for codex-lb persistence. +Define supported database backends and required runtime database behavior for codex-lb persistence. ## Requirements -### Requirement: SQLite remains the default backend -The service MUST default `CODEX_LB_DATABASE_URL` to a SQLite DSN when no explicit database URL is provided. +### Requirement: PostgreSQL on Neon is the required runtime backend +The service MUST require `CODEX_LB_DATABASE_URL` to be set to a PostgreSQL SQLAlchemy async DSN when the application starts. #### Scenario: No database URL configured - **WHEN** the service starts without `CODEX_LB_DATABASE_URL` -- **THEN** it initializes and runs against the default SQLite database path - -### Requirement: PostgreSQL is supported as an optional backend -The service MUST accept a PostgreSQL SQLAlchemy async DSN (`postgresql+asyncpg://...`) via `CODEX_LB_DATABASE_URL` and initialize SQLAlchemy session/engine wiring without requiring SQLite-specific paths. +- **THEN** settings initialization fails with an explicit configuration error #### Scenario: PostgreSQL URL configured - **WHEN** `CODEX_LB_DATABASE_URL` is set to `postgresql+asyncpg://...` -- **THEN** service startup uses PostgreSQL for ORM operations and migration execution - -### Requirement: SQLite startup validation mode is configurable -The service MUST support configurable startup validation for SQLite file databases via `CODEX_LB_DATABASE_SQLITE_STARTUP_CHECK_MODE`. +- **THEN** service startup uses PostgreSQL for ORM operations +- **AND** it does not perform SQLite-specific startup validation or file-path setup -#### Scenario: Default SQLite startup uses quick validation -- **GIVEN** the configured database URL is a SQLite file -- **AND** `CODEX_LB_DATABASE_SQLITE_STARTUP_CHECK_MODE` is unset -- **WHEN** the service starts -- **THEN** it runs `PRAGMA quick_check` -- **AND** it does not run `PRAGMA integrity_check` +### Requirement: Runtime migrations use a dedicated migration URL +The service MUST accept a dedicated PostgreSQL DSN via `CODEX_LB_DATABASE_MIGRATION_URL` for Alembic and startup migration execution. -#### Scenario: Full SQLite startup validation is explicitly enabled -- **GIVEN** the configured database URL is a SQLite file -- **AND** `CODEX_LB_DATABASE_SQLITE_STARTUP_CHECK_MODE=full` -- **WHEN** the service starts -- **THEN** it runs `PRAGMA integrity_check` +#### Scenario: Dedicated migration URL configured +- **WHEN** `CODEX_LB_DATABASE_MIGRATION_URL` is set to `postgresql+asyncpg://...` +- **THEN** startup migration and Alembic execution use that DSN +- **AND** runtime ORM sessions still use `CODEX_LB_DATABASE_URL` -#### Scenario: SQLite startup validation can be skipped -- **GIVEN** the configured database URL is a SQLite file -- **AND** `CODEX_LB_DATABASE_SQLITE_STARTUP_CHECK_MODE=off` -- **WHEN** the service starts -- **THEN** it skips startup SQLite validation +#### Scenario: Migration URL omitted +- **WHEN** `CODEX_LB_DATABASE_MIGRATION_URL` is not set +- **THEN** the service uses `CODEX_LB_DATABASE_URL` as the migration DSN -### Requirement: Test suite supports backend selection -The test bootstrap MUST allow callers to override `CODEX_LB_DATABASE_URL` via environment and MUST default to SQLite when no override is provided. +### Requirement: Test suite requires explicit PostgreSQL backend configuration +The test bootstrap MUST allow callers to override runtime and migration DSNs via `CODEX_LB_TEST_DATABASE_URL` and `CODEX_LB_TEST_DATABASE_MIGRATION_URL` and MUST NOT silently default to SQLite. -#### Scenario: CI sets PostgreSQL URL -- **WHEN** CI sets `CODEX_LB_DATABASE_URL` to a PostgreSQL DSN +#### Scenario: CI sets PostgreSQL URLs +- **WHEN** CI sets `CODEX_LB_TEST_DATABASE_URL` and `CODEX_LB_TEST_DATABASE_MIGRATION_URL` - **THEN** tests run against PostgreSQL without modifying test code -#### Scenario: Local test run without URL override -- **WHEN** tests are run without setting `CODEX_LB_DATABASE_URL` -- **THEN** tests run against a temporary SQLite database - -### Requirement: CI validates both default and optional backends -CI MUST keep SQLite-backed tests as the default path and MUST run an additional PostgreSQL-backed test job. - -#### Scenario: CI workflow execution -- **WHEN** CI runs on push or pull request -- **THEN** at least one pytest job runs with SQLite and another pytest job runs with PostgreSQL +#### Scenario: Test run without URL override +- **WHEN** tests are run without `CODEX_LB_TEST_DATABASE_URL` +- **THEN** test bootstrap fails with an explicit configuration error ### Requirement: ORM enums persist schema string values ORM enum columns backed by named PostgreSQL enums MUST persist the lowercase string values defined by the schema and migrations, not Python enum member names. diff --git a/openspec/specs/database-migrations/context.md b/openspec/specs/database-migrations/context.md index c465a2fc..4e80a493 100644 --- a/openspec/specs/database-migrations/context.md +++ b/openspec/specs/database-migrations/context.md @@ -42,8 +42,10 @@ ## Operational Notes +- Migration DSN resolution: + - startup and Alembic CLI use `CODEX_LB_DATABASE_MIGRATION_URL` when set, otherwise `CODEX_LB_DATABASE_URL` - Startup path: - - inspect state -> (optional SQLite backup) -> bootstrap legacy `schema_migrations` -> remap legacy Alembic IDs -> `upgrade head` -> schema drift check + - inspect state -> bootstrap legacy `schema_migrations` -> remap legacy Alembic IDs -> `upgrade head` -> schema drift check - CLI checks: - `codex-lb-db check` validates head count, revision naming/filename policy, and schema drift. - Emergency toggle: diff --git a/openspec/specs/database-migrations/spec.md b/openspec/specs/database-migrations/spec.md index 9e410ec6..bc71192e 100644 --- a/openspec/specs/database-migrations/spec.md +++ b/openspec/specs/database-migrations/spec.md @@ -94,18 +94,20 @@ The migration chain SHALL be idempotent for fresh databases and partially migrat - **THEN** schema state remains stable - **AND** the current Alembic revision remains `head` -### Requirement: Automatic SQLite pre-migration backup +### Requirement: Startup migration uses the resolved migration DSN -The system SHALL create a SQLite backup before applying startup migrations when an upgrade is needed. +The system SHALL run startup migrations, revision inspection, and schema drift checks against the resolved migration DSN. -#### Scenario: Startup detects pending migration on SQLite +#### Scenario: Dedicated migration DSN configured -- **GIVEN** the configured database is a SQLite file -- **AND** startup migration is enabled -- **AND** migration state indicates upgrade is required -- **WHEN** startup migration begins -- **THEN** the system creates a pre-migration backup file -- **AND** enforces configured retention on backup files +- **WHEN** `CODEX_LB_DATABASE_MIGRATION_URL` is set +- **THEN** startup migrations, drift checks, and Alembic CLI defaults use that DSN +- **AND** runtime ORM sessions continue using `CODEX_LB_DATABASE_URL` + +#### Scenario: Dedicated migration DSN omitted + +- **WHEN** `CODEX_LB_DATABASE_MIGRATION_URL` is unset +- **THEN** startup migration flow uses `CODEX_LB_DATABASE_URL` ### Requirement: Migration policy and drift guard in CI diff --git a/openspec/specs/query-caching/context.md b/openspec/specs/query-caching/context.md index 348c9fa8..a3e9b974 100644 --- a/openspec/specs/query-caching/context.md +++ b/openspec/specs/query-caching/context.md @@ -1,6 +1,6 @@ ## Overview -The query-caching capability is broader than cache TTLs. It also owns the database query shapes that sit on hot request and dashboard paths, especially when SQLite is the default backend. +The query-caching capability is broader than cache TTLs. It also owns the database query shapes that sit on hot request and dashboard paths, especially for the PostgreSQL-backed runtime used by codex-lb. ## Decisions diff --git a/tests/conftest.py b/tests/conftest.py index 89e8eec8..53f24075 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,6 @@ from __future__ import annotations import os -import tempfile -from pathlib import Path from uuid import uuid4 import pytest @@ -10,11 +8,14 @@ from httpx import ASGITransport, AsyncClient from sqlalchemy import text -TEST_DB_DIR = Path(tempfile.mkdtemp(prefix="codex-lb-tests-")) -TEST_DB_PATH = TEST_DB_DIR / "codex-lb.db" +_runtime_db_url = os.environ.get("CODEX_LB_TEST_DATABASE_URL") +if not _runtime_db_url: + raise RuntimeError("CODEX_LB_TEST_DATABASE_URL must be set for the test suite") -os.environ["CODEX_LB_DATABASE_URL"] = os.environ.get( - "CODEX_LB_TEST_DATABASE_URL", f"sqlite+aiosqlite:///{TEST_DB_PATH}" +os.environ["CODEX_LB_DATABASE_URL"] = _runtime_db_url +os.environ["CODEX_LB_DATABASE_MIGRATION_URL"] = os.environ.get( + "CODEX_LB_TEST_DATABASE_MIGRATION_URL", + _runtime_db_url, ) os.environ["CODEX_LB_UPSTREAM_BASE_URL"] = "https://example.invalid/backend-api" os.environ["CODEX_LB_USAGE_REFRESH_ENABLED"] = "false" @@ -70,8 +71,8 @@ async def async_client(app_instance): @pytest.fixture(autouse=True) -def temp_key_file(monkeypatch): - key_path = TEST_DB_DIR / f"encryption-{uuid4().hex}.key" +def temp_key_file(monkeypatch, tmp_path): + key_path = tmp_path / f"encryption-{uuid4().hex}.key" monkeypatch.setenv("CODEX_LB_ENCRYPTION_KEY_FILE", str(key_path)) from app.core.config.settings import get_settings diff --git a/tests/unit/test_db_session.py b/tests/unit/test_db_session.py index e9adcefb..4626d309 100644 --- a/tests/unit/test_db_session.py +++ b/tests/unit/test_db_session.py @@ -10,28 +10,16 @@ import pytest import app.db.session as session_module -from app.db.sqlite_utils import IntegrityCheck, SqliteIntegrityCheckMode @dataclass(slots=True) class _FakeSettings: database_url: str + database_migration_url: str | None = None database_migrate_on_startup: bool = True - database_sqlite_pre_migrate_backup_enabled: bool = False - database_sqlite_pre_migrate_backup_max_files: int = 5 - database_sqlite_startup_check_mode: str = "quick" database_migrations_fail_fast: bool = False -@dataclass(slots=True) -class _FakeMigrationState: - current_revision: str | None - head_revision: str - has_alembic_version_table: bool - has_legacy_migrations_table: bool - needs_upgrade: bool - - @dataclass(slots=True) class _FakeBootstrap: stamped_revision: str | None = None @@ -44,13 +32,14 @@ class _FakeMigrationRunResult: bootstrap: _FakeBootstrap = field(default_factory=_FakeBootstrap) -def test_import_session_with_sqlite_memory_url_does_not_error() -> None: +def test_import_session_requires_database_url() -> None: repo_root = Path(__file__).resolve().parents[2] env = os.environ.copy() - env["CODEX_LB_DATABASE_URL"] = "sqlite+aiosqlite:///:memory:" + env["CODEX_LB_DATABASE_URL"] = "" + env["CODEX_LB_DATABASE_MIGRATION_URL"] = "" result = subprocess.run( - [sys.executable, "-c", "import sys; import app.db.session; assert 'app.db.migrate' not in sys.modules"], + [sys.executable, "-c", "import app.db.session"], cwd=repo_root, env=env, capture_output=True, @@ -58,13 +47,28 @@ def test_import_session_with_sqlite_memory_url_does_not_error() -> None: check=False, ) - assert result.returncode == 0, result.stderr or result.stdout + assert result.returncode != 0 + assert "database_url" in (result.stderr or result.stdout) + + +def test_runtime_database_url_normalizes_asyncpg_ssl_query_params() -> None: + url = ( + "postgresql+asyncpg://user:pass@host/db?sslmode=require&channel_binding=require&application_name=codex-lb" + ) + + normalized = session_module._runtime_database_url(url) + + assert "ssl=require" in normalized + assert "sslmode=" not in normalized + assert "channel_binding=" not in normalized + assert "application_name=codex-lb" in normalized def test_import_session_with_postgres_url_does_not_error() -> None: repo_root = Path(__file__).resolve().parents[2] env = os.environ.copy() env["CODEX_LB_DATABASE_URL"] = "postgresql+asyncpg://codex_lb:codex_lb@127.0.0.1:5432/codex_lb" + env["CODEX_LB_DATABASE_MIGRATION_URL"] = env["CODEX_LB_DATABASE_URL"] result = subprocess.run( [sys.executable, "-c", "import app.db.session"], @@ -79,104 +83,83 @@ def test_import_session_with_postgres_url_does_not_error() -> None: @pytest.mark.asyncio -async def test_init_db_fails_when_migration_module_is_missing_even_with_fail_fast_disabled(monkeypatch) -> None: - def _raise_missing_migration() -> tuple[object, object]: - raise ModuleNotFoundError("No module named 'app.db.migrate'", name="app.db.migrate") - +async def test_init_db_requires_migration_url_when_startup_migrations_enabled(monkeypatch) -> None: monkeypatch.setattr( session_module, "_settings", - _FakeSettings(database_url="sqlite+aiosqlite:///:memory:", database_migrations_fail_fast=False), + _FakeSettings( + database_url="postgresql+asyncpg://codex_lb:codex_lb@127.0.0.1:5432/codex_lb", + database_migration_url=None, + database_migrate_on_startup=True, + ), ) - monkeypatch.setattr(session_module, "_load_migration_entrypoints", _raise_missing_migration) - with pytest.raises(RuntimeError, match="app\\.db\\.migrate is unavailable"): + with pytest.raises(RuntimeError, match="CODEX_LB_DATABASE_MIGRATION_URL is required"): await session_module.init_db() @pytest.mark.asyncio -async def test_init_db_fails_when_migration_entrypoint_is_invalid_even_with_fail_fast_disabled(monkeypatch) -> None: - def _raise_invalid_migration() -> tuple[object, object]: - raise ImportError("cannot import name 'run_startup_migrations' from 'app.db.migrate'") +async def test_init_db_fails_when_migration_module_is_missing_even_with_fail_fast_disabled(monkeypatch) -> None: + def _raise_missing_migration() -> tuple[object, object]: + raise ModuleNotFoundError("No module named 'app.db.migrate'", name="app.db.migrate") monkeypatch.setattr( session_module, "_settings", - _FakeSettings(database_url="sqlite+aiosqlite:///:memory:", database_migrations_fail_fast=False), + _FakeSettings( + database_url="postgresql+asyncpg://codex_lb:codex_lb@127.0.0.1:5432/codex_lb", + database_migration_url="postgresql+asyncpg://codex_lb:codex_lb@127.0.0.1:5432/codex_lb", + database_migrations_fail_fast=False, + ), ) - monkeypatch.setattr(session_module, "_load_migration_entrypoints", _raise_invalid_migration) + monkeypatch.setattr(session_module, "_load_migration_entrypoints", _raise_missing_migration) - with pytest.raises(RuntimeError, match="app\\.db\\.migrate is invalid"): + with pytest.raises(RuntimeError, match=r"app\.db\.migrate is unavailable"): await session_module.init_db() @pytest.mark.asyncio -async def test_init_db_fails_when_backup_module_is_missing_even_with_fail_fast_disabled(monkeypatch, tmp_path) -> None: - db_path = tmp_path / "store.db" - db_path.write_bytes(b"") - - def _inspect_migration_state(_: str) -> _FakeMigrationState: - return _FakeMigrationState( - current_revision=None, - head_revision="head", - has_alembic_version_table=False, - has_legacy_migrations_table=False, - needs_upgrade=True, - ) - - async def _run_startup_migrations(_: str) -> _FakeMigrationRunResult: - return _FakeMigrationRunResult() - - def _check_schema_drift(_: str) -> tuple[str, ...]: - return () - - def _load_entrypoints() -> tuple[object, object, object]: - return _inspect_migration_state, _run_startup_migrations, _check_schema_drift - - def _raise_missing_backup() -> object: - raise ModuleNotFoundError("No module named 'app.db.backup'", name="app.db.backup") +async def test_init_db_fails_when_migration_entrypoint_is_invalid_even_with_fail_fast_disabled(monkeypatch) -> None: + def _raise_invalid_migration() -> tuple[object, object]: + raise ImportError("cannot import name 'run_startup_migrations' from 'app.db.migrate'") monkeypatch.setattr( session_module, "_settings", _FakeSettings( - database_url=f"sqlite+aiosqlite:///{db_path}", - database_sqlite_pre_migrate_backup_enabled=True, + database_url="postgresql+asyncpg://codex_lb:codex_lb@127.0.0.1:5432/codex_lb", + database_migration_url="postgresql+asyncpg://codex_lb:codex_lb@127.0.0.1:5432/codex_lb", database_migrations_fail_fast=False, ), ) - monkeypatch.setattr(session_module, "_load_migration_entrypoints", _load_entrypoints) - monkeypatch.setattr(session_module, "_load_sqlite_backup_creator", _raise_missing_backup) + monkeypatch.setattr(session_module, "_load_migration_entrypoints", _raise_invalid_migration) - with pytest.raises(RuntimeError, match="app\\.db\\.backup is unavailable"): + with pytest.raises(RuntimeError, match=r"app\.db\.migrate is invalid"): await session_module.init_db() @pytest.mark.asyncio async def test_init_db_fails_fast_on_post_migration_schema_drift(monkeypatch) -> None: - async def _run_startup_migrations(_: str) -> _FakeMigrationRunResult: - return _FakeMigrationRunResult() + seen: list[str] = [] - def _inspect_migration_state(_: str) -> _FakeMigrationState: - return _FakeMigrationState( - current_revision="head", - head_revision="head", - has_alembic_version_table=True, - has_legacy_migrations_table=False, - needs_upgrade=False, - ) + async def _run_startup_migrations(url: str) -> _FakeMigrationRunResult: + seen.append(url) + return _FakeMigrationRunResult() - def _check_schema_drift(_: str) -> tuple[str, ...]: + def _check_schema_drift(url: str) -> tuple[str, ...]: + seen.append(url) return ("('add_table', 'additional_usage_history')",) - def _load_entrypoints() -> tuple[object, object, object]: - return _inspect_migration_state, _run_startup_migrations, _check_schema_drift + def _load_entrypoints() -> tuple[object, object]: + return _run_startup_migrations, _check_schema_drift + migration_url = "postgresql+asyncpg://migrate:migrate@127.0.0.1:5432/codex_lb" monkeypatch.setattr( session_module, "_settings", _FakeSettings( - database_url="sqlite+aiosqlite:///:memory:", + database_url="postgresql+asyncpg://runtime:runtime@127.0.0.1:5432/codex_lb", + database_migration_url=migration_url, database_migrations_fail_fast=True, ), ) @@ -185,32 +168,26 @@ def _load_entrypoints() -> tuple[object, object, object]: with pytest.raises(RuntimeError, match="Schema drift detected after startup migrations"): await session_module.init_db() + assert seen == [migration_url, migration_url] + @pytest.mark.asyncio async def test_init_db_logs_post_migration_schema_drift_when_fail_fast_disabled(monkeypatch, caplog) -> None: async def _run_startup_migrations(_: str) -> _FakeMigrationRunResult: return _FakeMigrationRunResult() - def _inspect_migration_state(_: str) -> _FakeMigrationState: - return _FakeMigrationState( - current_revision="head", - head_revision="head", - has_alembic_version_table=True, - has_legacy_migrations_table=False, - needs_upgrade=False, - ) - def _check_schema_drift(_: str) -> tuple[str, ...]: return ("('missing_index', 'request_logs', 'idx_logs_requested_at_id')",) - def _load_entrypoints() -> tuple[object, object, object]: - return _inspect_migration_state, _run_startup_migrations, _check_schema_drift + def _load_entrypoints() -> tuple[object, object]: + return _run_startup_migrations, _check_schema_drift monkeypatch.setattr( session_module, "_settings", _FakeSettings( - database_url="sqlite+aiosqlite:///:memory:", + database_url="postgresql+asyncpg://runtime:runtime@127.0.0.1:5432/codex_lb", + database_migration_url="postgresql+asyncpg://migrate:migrate@127.0.0.1:5432/codex_lb", database_migrations_fail_fast=False, ), ) @@ -226,75 +203,19 @@ def _load_entrypoints() -> tuple[object, object, object]: @pytest.mark.asyncio -async def test_init_db_uses_quick_check_by_default(monkeypatch, tmp_path) -> None: - db_path = tmp_path / "store.db" - db_path.write_bytes(b"sqlite") - seen: list[SqliteIntegrityCheckMode] = [] - - def _check(path: Path, *, mode: SqliteIntegrityCheckMode = SqliteIntegrityCheckMode.FULL) -> IntegrityCheck: - assert path == db_path - seen.append(mode) - return IntegrityCheck(ok=True, details=None) +async def test_init_db_skips_startup_migration_when_disabled(monkeypatch) -> None: + def _load_entrypoints() -> tuple[object, object]: + raise AssertionError("migration entrypoints should not load when startup migrations are disabled") monkeypatch.setattr( session_module, "_settings", _FakeSettings( - database_url=f"sqlite+aiosqlite:///{db_path}", + database_url="postgresql+asyncpg://runtime:runtime@127.0.0.1:5432/codex_lb", + database_migration_url=None, database_migrate_on_startup=False, ), ) - monkeypatch.setattr(session_module, "check_sqlite_integrity", _check) - - await session_module.init_db() - - assert seen == [SqliteIntegrityCheckMode.QUICK] - - -@pytest.mark.asyncio -async def test_init_db_uses_full_check_when_configured(monkeypatch, tmp_path) -> None: - db_path = tmp_path / "store.db" - db_path.write_bytes(b"sqlite") - seen: list[SqliteIntegrityCheckMode] = [] - - def _check(path: Path, *, mode: SqliteIntegrityCheckMode = SqliteIntegrityCheckMode.FULL) -> IntegrityCheck: - assert path == db_path - seen.append(mode) - return IntegrityCheck(ok=True, details=None) - - monkeypatch.setattr( - session_module, - "_settings", - _FakeSettings( - database_url=f"sqlite+aiosqlite:///{db_path}", - database_migrate_on_startup=False, - database_sqlite_startup_check_mode="full", - ), - ) - monkeypatch.setattr(session_module, "check_sqlite_integrity", _check) - - await session_module.init_db() - - assert seen == [SqliteIntegrityCheckMode.FULL] - - -@pytest.mark.asyncio -async def test_init_db_skips_sqlite_check_when_disabled(monkeypatch, tmp_path) -> None: - db_path = tmp_path / "store.db" - db_path.write_bytes(b"sqlite") - - def _check(_: Path, *, mode: SqliteIntegrityCheckMode = SqliteIntegrityCheckMode.FULL) -> IntegrityCheck: - raise AssertionError("sqlite startup check should be skipped when disabled") - - monkeypatch.setattr( - session_module, - "_settings", - _FakeSettings( - database_url=f"sqlite+aiosqlite:///{db_path}", - database_migrate_on_startup=False, - database_sqlite_startup_check_mode="off", - ), - ) - monkeypatch.setattr(session_module, "check_sqlite_integrity", _check) + monkeypatch.setattr(session_module, "_load_migration_entrypoints", _load_entrypoints) await session_module.init_db() From 972251abce0a35052594ee8fe8a07b75ceb9d755 Mon Sep 17 00:00:00 2001 From: xirothedev Date: Sun, 15 Mar 2026 22:18:53 +0700 Subject: [PATCH 3/7] fix(api-keys): normalize expiration datetimes --- app/modules/api_keys/service.py | 14 ++++-- .../proposal.md | 9 ++++ .../specs/api-keys/spec.md | 16 +++++++ .../fix-api-key-expiration-timezone/tasks.md | 5 +++ openspec/specs/api-keys/spec.md | 16 ++++++- tests/unit/test_api_keys_service.py | 44 ++++++++++++++++++- 6 files changed, 98 insertions(+), 6 deletions(-) create mode 100644 openspec/changes/fix-api-key-expiration-timezone/proposal.md create mode 100644 openspec/changes/fix-api-key-expiration-timezone/specs/api-keys/spec.md create mode 100644 openspec/changes/fix-api-key-expiration-timezone/tasks.md diff --git a/app/modules/api_keys/service.py b/app/modules/api_keys/service.py index e7b43ddc..2fd3731f 100644 --- a/app/modules/api_keys/service.py +++ b/app/modules/api_keys/service.py @@ -13,7 +13,7 @@ calculate_cost_from_usage, get_pricing_for_model, ) -from app.core.utils.time import utcnow +from app.core.utils.time import to_utc_naive, utcnow from app.db.models import ApiKey, ApiKeyLimit, LimitType, LimitWindow from app.modules.api_keys.repository import ( _UNSET, @@ -234,6 +234,7 @@ def __init__(self, repository: ApiKeysRepositoryProtocol) -> None: async def create_key(self, payload: ApiKeyCreateData) -> ApiKeyCreatedData: now = utcnow() + expires_at = _normalize_expires_at(payload.expires_at) plain_key = _generate_plain_key() normalized_allowed_models = _normalize_allowed_models(payload.allowed_models) enforced_model = _normalize_model_slug(payload.enforced_model) @@ -247,7 +248,7 @@ async def create_key(self, payload: ApiKeyCreateData) -> ApiKeyCreatedData: allowed_models=_serialize_allowed_models(normalized_allowed_models), enforced_model=enforced_model, enforced_reasoning_effort=enforced_reasoning_effort, - expires_at=payload.expires_at, + expires_at=expires_at, is_active=True, created_at=now, last_used_at=None, @@ -273,6 +274,7 @@ async def list_keys(self) -> list[ApiKeyData]: ] async def update_key(self, key_id: str, payload: ApiKeyUpdateData) -> ApiKeyData: + expires_at = _normalize_expires_at(payload.expires_at) if payload.expires_at_set else None if payload.allowed_models_set: allowed_models = _normalize_allowed_models(payload.allowed_models) else: @@ -309,7 +311,7 @@ async def update_key(self, key_id: str, payload: ApiKeyUpdateData) -> ApiKeyData allowed_models=_serialize_allowed_models(allowed_models) if payload.allowed_models_set else _UNSET, enforced_model=enforced_model if payload.enforced_model_set else _UNSET, enforced_reasoning_effort=(enforced_reasoning_effort if payload.enforced_reasoning_effort_set else _UNSET), - expires_at=payload.expires_at if payload.expires_at_set else _UNSET, + expires_at=expires_at if payload.expires_at_set else _UNSET, is_active=(payload.is_active if payload.is_active_set and payload.is_active is not None else _UNSET), ) if row is None: @@ -670,6 +672,12 @@ def _normalize_model_slug(value: str | None) -> str | None: _SUPPORTED_REASONING_EFFORTS = frozenset({"none", "minimal", "low", "medium", "high", "xhigh"}) +def _normalize_expires_at(value: datetime | None) -> datetime | None: + if value is None: + return None + return to_utc_naive(value) + + def _normalize_reasoning_effort(value: str | None) -> str | None: if value is None: return None diff --git a/openspec/changes/fix-api-key-expiration-timezone/proposal.md b/openspec/changes/fix-api-key-expiration-timezone/proposal.md new file mode 100644 index 00000000..574c5e35 --- /dev/null +++ b/openspec/changes/fix-api-key-expiration-timezone/proposal.md @@ -0,0 +1,9 @@ +# Proposal + +## Why +Creating or updating an API key with `expiresAt` currently fails against PostgreSQL when the payload uses an ISO 8601 datetime with timezone information. The dashboard sends timezone-aware values (for example `2026-03-20T23:59:59.000Z`), but the backend persists `expires_at` into a `timestamp without time zone` column without normalizing it first, causing asyncpg to reject the write. + +## What Changes +- Normalize API key expiration datetimes to UTC naive before persistence. +- Preserve the public contract that dashboard and API clients may submit ISO 8601 datetimes with timezone offsets for `expiresAt`. +- Add regression coverage for create and update flows with timezone-aware expiration values. diff --git a/openspec/changes/fix-api-key-expiration-timezone/specs/api-keys/spec.md b/openspec/changes/fix-api-key-expiration-timezone/specs/api-keys/spec.md new file mode 100644 index 00000000..187c0379 --- /dev/null +++ b/openspec/changes/fix-api-key-expiration-timezone/specs/api-keys/spec.md @@ -0,0 +1,16 @@ +## MODIFIED Requirements +### Requirement: API Key creation +The system SHALL allow the admin to create API keys via `POST /api/api-keys` with a `name` (required), `allowed_models` (optional list), `weekly_token_limit` (optional integer), and `expires_at` (optional ISO 8601 datetime). The system MUST accept timezone-aware ISO 8601 datetimes for `expiresAt`, normalize them to UTC naive for persistence, and return the expiration as UTC in API responses. + +#### Scenario: Create key with timezone-aware expiration +- **WHEN** admin submits `POST /api/api-keys` with `{ "name": "dev-key", "expiresAt": "2025-12-31T00:00:00Z" }` +- **THEN** the system persists the expiration successfully without PostgreSQL datetime binding errors +- **AND** the response returns `expiresAt` representing the same UTC instant + +### Requirement: API Key update +The system SHALL allow updating key properties via `PATCH /api/api-keys/{id}`. Updatable fields: `name`, `allowedModels`, `weeklyTokenLimit`, `expiresAt`, `isActive`. The key hash and prefix MUST NOT be modifiable. The system MUST accept timezone-aware ISO 8601 datetimes for `expiresAt` and normalize them to UTC naive before persistence. + +#### Scenario: Update key with timezone-aware expiration +- **WHEN** admin submits `PATCH /api/api-keys/{id}` with `{ "expiresAt": "2025-12-31T00:00:00Z" }` +- **THEN** the system persists the expiration successfully without PostgreSQL datetime binding errors +- **AND** the response returns `expiresAt` representing the same UTC instant diff --git a/openspec/changes/fix-api-key-expiration-timezone/tasks.md b/openspec/changes/fix-api-key-expiration-timezone/tasks.md new file mode 100644 index 00000000..c3a53308 --- /dev/null +++ b/openspec/changes/fix-api-key-expiration-timezone/tasks.md @@ -0,0 +1,5 @@ +# Tasks + +- [x] Normalize `expires_at` to UTC naive in the API key service before create and update writes. +- [x] Add regression tests covering timezone-aware expiration datetimes for create and update flows. +- [x] Update the API key spec to state that ISO 8601 expiration datetimes with offsets are accepted and normalized before persistence. diff --git a/openspec/specs/api-keys/spec.md b/openspec/specs/api-keys/spec.md index c9b22aec..039d6b22 100644 --- a/openspec/specs/api-keys/spec.md +++ b/openspec/specs/api-keys/spec.md @@ -5,13 +5,19 @@ TBD - created by archiving change admin-auth-and-api-keys. Update Purpose after ## Requirements ### Requirement: API Key creation -The system SHALL allow the admin to create API keys via `POST /api/api-keys` with a `name` (required), `allowed_models` (optional list), `weekly_token_limit` (optional integer), and `expires_at` (optional ISO 8601 datetime). The system MUST generate a key in the format `sk-clb-{48 hex chars}`, store only the `sha256` hash in the database, and return the plain key exactly once in the creation response. +The system SHALL allow the admin to create API keys via `POST /api/api-keys` with a `name` (required), `allowed_models` (optional list), `weekly_token_limit` (optional integer), and `expires_at` (optional ISO 8601 datetime). The system MUST generate a key in the format `sk-clb-{48 hex chars}`, store only the `sha256` hash in the database, and return the plain key exactly once in the creation response. The system MUST accept timezone-aware ISO 8601 datetimes for `expiresAt`, normalize them to UTC naive for persistence, and return the expiration as UTC in API responses. #### Scenario: Create key with all options - **WHEN** admin submits `POST /api/api-keys` with `{ "name": "dev-key", "allowedModels": ["o3-pro"], "weeklyTokenLimit": 1000000, "expiresAt": "2025-12-31T00:00:00Z" }` - **THEN** the system returns `{ "id": "", "name": "dev-key", "key": "sk-clb-...", "keyPrefix": "sk-clb-a1b2c3d4", "allowedModels": ["o3-pro"], "weeklyTokenLimit": 1000000, "expiresAt": "2025-12-31T00:00:00Z", "createdAt": "..." }` with the plain key visible only in this response +#### Scenario: Create key with timezone-aware expiration + +- **WHEN** admin submits `POST /api/api-keys` with `{ "name": "dev-key", "expiresAt": "2025-12-31T00:00:00Z" }` +- **THEN** the system persists the expiration successfully without PostgreSQL datetime binding errors +- **AND** the response returns `expiresAt` representing the same UTC instant + #### Scenario: Create key with defaults - **WHEN** admin submits `POST /api/api-keys` with `{ "name": "open-key" }` and no optional fields @@ -38,7 +44,7 @@ The system SHALL expose `GET /api/api-keys` returning all API keys with their me ### Requirement: API Key update -The system SHALL allow updating key properties via `PATCH /api/api-keys/{id}`. Updatable fields: `name`, `allowedModels`, `weeklyTokenLimit`, `expiresAt`, `isActive`. The key hash and prefix MUST NOT be modifiable. +The system SHALL allow updating key properties via `PATCH /api/api-keys/{id}`. Updatable fields: `name`, `allowedModels`, `weeklyTokenLimit`, `expiresAt`, `isActive`. The key hash and prefix MUST NOT be modifiable. The system MUST accept timezone-aware ISO 8601 datetimes for `expiresAt` and normalize them to UTC naive before persistence. #### Scenario: Update allowed models @@ -50,6 +56,12 @@ The system SHALL allow updating key properties via `PATCH /api/api-keys/{id}`. U - **WHEN** admin submits `PATCH /api/api-keys/{id}` with `{ "isActive": false }` - **THEN** the key is deactivated; subsequent Bearer requests using this key SHALL be rejected with 401 +#### Scenario: Update key with timezone-aware expiration + +- **WHEN** admin submits `PATCH /api/api-keys/{id}` with `{ "expiresAt": "2025-12-31T00:00:00Z" }` +- **THEN** the system persists the expiration successfully without PostgreSQL datetime binding errors +- **AND** the response returns `expiresAt` representing the same UTC instant + #### Scenario: Update non-existent key - **WHEN** admin submits `PATCH /api/api-keys/{id}` with an unknown ID diff --git a/tests/unit/test_api_keys_service.py b/tests/unit/test_api_keys_service.py index d5b5e757..e45b35b1 100644 --- a/tests/unit/test_api_keys_service.py +++ b/tests/unit/test_api_keys_service.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import pytest @@ -20,6 +20,7 @@ ApiKeyRateLimitExceededError, ApiKeysRepositoryProtocol, ApiKeysService, + ApiKeyUpdateData, LimitRuleInput, ) @@ -388,6 +389,26 @@ async def test_create_key_stores_hash_and_prefix() -> None: assert stored.key_prefix == created.key[:15] +@pytest.mark.asyncio +async def test_create_key_normalizes_timezone_aware_expiry_to_utc_naive() -> None: + repo = _FakeApiKeysRepository() + service = ApiKeysService(repo) + + created = await service.create_key( + ApiKeyCreateData( + name="expiring-key", + allowed_models=None, + expires_at=datetime(2026, 3, 20, 23, 59, 59, tzinfo=timezone.utc), + ) + ) + + assert created.expires_at == datetime(2026, 3, 20, 23, 59, 59) + + stored = await repo.get_by_id(created.id) + assert stored is not None + assert stored.expires_at == datetime(2026, 3, 20, 23, 59, 59) + + @pytest.mark.asyncio async def test_create_key_rejects_enforced_model_outside_allowed_models() -> None: repo = _FakeApiKeysRepository() @@ -618,6 +639,27 @@ async def test_enforce_limits_reserves_tier_aware_cost_budget() -> None: assert standard_cost_limit.current_value == 92_159 +@pytest.mark.asyncio +async def test_update_key_normalizes_timezone_aware_expiry_to_utc_naive() -> None: + repo = _FakeApiKeysRepository() + service = ApiKeysService(repo) + created = await service.create_key(ApiKeyCreateData(name="update-expiry", allowed_models=None, expires_at=None)) + + updated = await service.update_key( + created.id, + ApiKeyUpdateData( + expires_at=datetime(2026, 4, 1, 12, 30, 0, tzinfo=timezone.utc), + expires_at_set=True, + ), + ) + + assert updated.expires_at == datetime(2026, 4, 1, 12, 30, 0) + + stored = await repo.get_by_id(created.id) + assert stored is not None + assert stored.expires_at == datetime(2026, 4, 1, 12, 30, 0) + + @pytest.mark.asyncio async def test_regenerate_key_rotates_hash_and_prefix() -> None: repo = _FakeApiKeysRepository() From 72cd7e888585393d6681e782e9e5b6690aca56bc Mon Sep 17 00:00:00 2001 From: xirothedev Date: Mon, 16 Mar 2026 02:08:52 +0700 Subject: [PATCH 4/7] fix(proxy): scope previous response snapshots --- app/core/openai/requests.py | 12 +- .../20260315_120000_add_response_snapshots.py | 19 +- app/db/models.py | 1 + .../proxy/response_snapshots_repository.py | 21 ++- app/modules/proxy/service.py | 41 ++++- .../specs/responses-api-compat/spec.md | 9 +- openspec/specs/responses-api-compat/spec.md | 9 +- tests/integration/test_migrations.py | 36 ++++ .../test_openai_compat_features.py | 94 ++++++++++ .../test_proxy_websocket_responses.py | 163 ++++++++++++++++++ tests/unit/test_proxy_utils.py | 86 ++++++++- 11 files changed, 463 insertions(+), 28 deletions(-) diff --git a/app/core/openai/requests.py b/app/core/openai/requests.py index 85ab82c5..8e074e92 100644 --- a/app/core/openai/requests.py +++ b/app/core/openai/requests.py @@ -116,7 +116,7 @@ def _is_input_file_with_id(item: Mapping[str, JsonValue]) -> bool: return isinstance(file_id, str) and bool(file_id) -def _sanitize_input_items(input_items: list[JsonValue]) -> list[JsonValue]: +def sanitize_input_items(input_items: list[JsonValue]) -> list[JsonValue]: sanitized_input: list[JsonValue] = [] for item in input_items: sanitized_item = _sanitize_interleaved_reasoning_input_item(item) @@ -336,12 +336,12 @@ def _validate_input_type(cls, value: JsonValue) -> JsonValue: normalized = _normalize_input_text(value) if _has_input_file_id(normalized): raise ValueError("input_file.file_id is not supported") - return _sanitize_input_items(normalized) + return sanitize_input_items(normalized) if is_json_list(value): input_items = cast(list[JsonValue], value) if _has_input_file_id(input_items): raise ValueError("input_file.file_id is not supported") - return _sanitize_input_items(input_items) + return sanitize_input_items(input_items) raise ValueError("input must be a string or array") @field_validator("include") @@ -411,12 +411,12 @@ def _validate_input_type(cls, value: JsonValue) -> JsonValue: normalized = _normalize_input_text(value) if _has_input_file_id(normalized): raise ValueError("input_file.file_id is not supported") - return _sanitize_input_items(normalized) + return sanitize_input_items(normalized) if is_json_list(value): input_items = cast(list[JsonValue], value) if _has_input_file_id(input_items): raise ValueError("input_file.file_id is not supported") - return _sanitize_input_items(input_items) + return sanitize_input_items(input_items) raise ValueError("input must be a string or array") @model_validator(mode="before") @@ -471,7 +471,7 @@ def _sanitize_interleaved_reasoning_input(payload: dict[str, JsonValue]) -> None input_items = _json_list_or_none(input_value) if input_items is None: return - payload["input"] = _sanitize_input_items(input_items) + payload["input"] = sanitize_input_items(input_items) def _normalize_openai_compatible_aliases(payload: dict[str, JsonValue]) -> None: diff --git a/app/db/alembic/versions/20260315_120000_add_response_snapshots.py b/app/db/alembic/versions/20260315_120000_add_response_snapshots.py index 21c88514..85dff36a 100644 --- a/app/db/alembic/versions/20260315_120000_add_response_snapshots.py +++ b/app/db/alembic/versions/20260315_120000_add_response_snapshots.py @@ -23,6 +23,13 @@ def _table_exists(connection: Connection, table_name: str) -> bool: return inspector.has_table(table_name) +def _columns(connection: Connection, table_name: str) -> set[str]: + inspector = sa.inspect(connection) + if not inspector.has_table(table_name): + return set() + return {str(column["name"]) for column in inspector.get_columns(table_name) if column.get("name") is not None} + + def _indexes(connection: Connection, table_name: str) -> set[str]: inspector = sa.inspect(connection) if not inspector.has_table(table_name): @@ -38,12 +45,16 @@ def upgrade() -> None: sa.Column("response_id", sa.String(), nullable=False), sa.Column("parent_response_id", sa.String(), nullable=True), sa.Column("account_id", sa.String(), nullable=True), + sa.Column("api_key_id", sa.String(), nullable=True), sa.Column("model", sa.String(), nullable=False), sa.Column("input_items_json", sa.Text(), nullable=False), sa.Column("response_json", sa.Text(), nullable=False), sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()), sa.PrimaryKeyConstraint("response_id"), ) + existing_columns = _columns(bind, "response_snapshots") + if "api_key_id" not in existing_columns: + op.add_column("response_snapshots", sa.Column("api_key_id", sa.String(), nullable=True)) existing_indexes = _indexes(bind, "response_snapshots") if "idx_response_snapshots_parent_created_at" not in existing_indexes: op.create_index( @@ -55,4 +66,10 @@ def upgrade() -> None: def downgrade() -> None: - return + bind = op.get_bind() + if not _table_exists(bind, "response_snapshots"): + return + existing_indexes = _indexes(bind, "response_snapshots") + if "idx_response_snapshots_parent_created_at" in existing_indexes: + op.drop_index("idx_response_snapshots_parent_created_at", table_name="response_snapshots") + op.drop_table("response_snapshots") diff --git a/app/db/models.py b/app/db/models.py index 73d6ee42..05e32e26 100644 --- a/app/db/models.py +++ b/app/db/models.py @@ -138,6 +138,7 @@ class ResponseSnapshot(Base): response_id: Mapped[str] = mapped_column(String, primary_key=True) parent_response_id: Mapped[str | None] = mapped_column(String, nullable=True) account_id: Mapped[str | None] = mapped_column(String, nullable=True) + api_key_id: Mapped[str | None] = mapped_column(String, nullable=True) model: Mapped[str] = mapped_column(String, nullable=False) input_items_json: Mapped[str] = mapped_column(Text, nullable=False) response_json: Mapped[str] = mapped_column(Text, nullable=False) diff --git a/app/modules/proxy/response_snapshots_repository.py b/app/modules/proxy/response_snapshots_repository.py index 9665cc5b..49947c20 100644 --- a/app/modules/proxy/response_snapshots_repository.py +++ b/app/modules/proxy/response_snapshots_repository.py @@ -6,7 +6,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.sql import Insert, func +from sqlalchemy.sql import Insert from app.core.types import JsonValue from app.db.models import ResponseSnapshot @@ -16,12 +16,15 @@ class ResponseSnapshotsRepository: def __init__(self, session: AsyncSession) -> None: self._session = session - async def get(self, response_id: str) -> ResponseSnapshot | None: + async def get(self, response_id: str, *, api_key_id: str | None) -> ResponseSnapshot | None: if not response_id: return None - result = await self._session.execute( - select(ResponseSnapshot).where(ResponseSnapshot.response_id == response_id) - ) + statement = select(ResponseSnapshot).where(ResponseSnapshot.response_id == response_id) + if api_key_id is None: + statement = statement.where(ResponseSnapshot.api_key_id.is_(None)) + else: + statement = statement.where(ResponseSnapshot.api_key_id == api_key_id) + result = await self._session.execute(statement) return result.scalar_one_or_none() async def upsert( @@ -30,6 +33,7 @@ async def upsert( response_id: str, parent_response_id: str | None, account_id: str | None, + api_key_id: str | None, model: str, input_items: list[JsonValue], response_payload: dict[str, JsonValue], @@ -38,13 +42,14 @@ async def upsert( response_id=response_id, parent_response_id=parent_response_id, account_id=account_id, + api_key_id=api_key_id, model=model, input_items_json=json.dumps(input_items, ensure_ascii=False, separators=(",", ":")), response_json=json.dumps(response_payload, ensure_ascii=False, separators=(",", ":")), ) await self._session.execute(statement) await self._session.commit() - snapshot = await self.get(response_id) + snapshot = await self.get(response_id, api_key_id=api_key_id) if snapshot is None: raise RuntimeError(f"ResponseSnapshot upsert failed for response_id={response_id!r}") await self._session.refresh(snapshot) @@ -56,6 +61,7 @@ def _build_upsert_statement( response_id: str, parent_response_id: str | None, account_id: str | None, + api_key_id: str | None, model: str, input_items_json: str, response_json: str, @@ -71,6 +77,7 @@ def _build_upsert_statement( response_id=response_id, parent_response_id=parent_response_id, account_id=account_id, + api_key_id=api_key_id, model=model, input_items_json=input_items_json, response_json=response_json, @@ -80,9 +87,9 @@ def _build_upsert_statement( set_={ "parent_response_id": parent_response_id, "account_id": account_id, + "api_key_id": api_key_id, "model": model, "input_items_json": input_items_json, "response_json": response_json, - "created_at": func.now(), }, ) diff --git a/app/modules/proxy/service.py b/app/modules/proxy/service.py index c851d422..406532a5 100644 --- a/app/modules/proxy/service.py +++ b/app/modules/proxy/service.py @@ -52,7 +52,7 @@ from app.core.openai.exceptions import ClientPayloadError from app.core.openai.models import CompactResponsePayload, OpenAIEvent, OpenAIResponsePayload from app.core.openai.parsing import parse_response_payload, parse_sse_event -from app.core.openai.requests import ResponsesCompactRequest, ResponsesRequest, _sanitize_input_items +from app.core.openai.requests import ResponsesCompactRequest, ResponsesRequest, sanitize_input_items from app.core.types import JsonValue from app.core.usage.types import UsageWindowRow from app.core.utils.request_id import ensure_request_id, get_request_id @@ -840,7 +840,10 @@ async def _prepare_websocket_response_create_request( apply_api_key_enforcement(responses_payload, refreshed_api_key) validate_model_access(refreshed_api_key, responses_payload.model) - resolved_request = await self._resolve_previous_response_request(responses_payload) + resolved_request = await self._resolve_previous_response_request( + responses_payload, + api_key_id=refreshed_api_key.id if refreshed_api_key else None, + ) responses_payload = resolved_request.payload upstream_payload = dict(responses_payload.to_payload()) @@ -864,6 +867,7 @@ async def _prepare_websocket_response_create_request( preferred_account_id=resolved_request.preferred_account_id, current_input_items=resolved_request.current_input_items, awaiting_response_created=True, + api_key_id=refreshed_api_key.id if refreshed_api_key else None, ), affinity_policy=_sticky_key_for_responses_request( responses_payload, @@ -1479,6 +1483,7 @@ async def _finalize_websocket_request_state( response_id=response_id, parent_response_id=request_state.parent_response_id, account_id=account_id_value, + api_key_id=request_state.api_key_id, model=request_state.model or "", input_items=request_state.current_input_items, response_payload=_terminal_response_payload(payload, request_state.output_items), @@ -1859,7 +1864,10 @@ async def _stream_with_retry( base_settings = get_settings() settings = await get_settings_cache().get() deadline = start + base_settings.proxy_request_budget_seconds - resolved_request = await self._resolve_previous_response_request(payload) + resolved_request = await self._resolve_previous_response_request( + payload, + api_key_id=api_key.id if api_key else None, + ) payload = resolved_request.payload prefer_earlier_reset = settings.prefer_earlier_reset_accounts upstream_stream_transport = _resolve_upstream_stream_transport(settings.upstream_stream_transport) @@ -2498,6 +2506,7 @@ async def _stream_once( response_id=response_id, parent_response_id=parent_response_id, account_id=account_id_value, + api_key_id=api_key.id if api_key else None, model=model, input_items=snapshot_input_items, response_payload=completed_response_payload, @@ -2806,13 +2815,21 @@ async def _select_account_with_budget( logger.warning("%s account selection exceeded request budget request_id=%s", kind.title(), request_id) _raise_proxy_budget_exhausted() - async def _resolve_previous_response_request(self, payload: ResponsesRequest) -> _ResolvedResponsesRequest: + async def _resolve_previous_response_request( + self, + payload: ResponsesRequest, + *, + api_key_id: str | None, + ) -> _ResolvedResponsesRequest: current_input_items = _clone_json_list(payload.input) previous_response_id = payload.previous_response_id if not previous_response_id: return _ResolvedResponsesRequest(payload=payload, current_input_items=current_input_items) - replay_items, preferred_account_id = await self._resolve_previous_response_chain(previous_response_id) + replay_items, preferred_account_id = await self._resolve_previous_response_chain( + previous_response_id, + api_key_id=api_key_id, + ) resolved_payload = payload.model_copy( deep=True, update={ @@ -2827,7 +2844,12 @@ async def _resolve_previous_response_request(self, payload: ResponsesRequest) -> preferred_account_id=preferred_account_id, ) - async def _resolve_previous_response_chain(self, response_id: str) -> tuple[list[JsonValue], str | None]: + async def _resolve_previous_response_chain( + self, + response_id: str, + *, + api_key_id: str | None, + ) -> tuple[list[JsonValue], str | None]: if not response_id: _raise_unknown_previous_response_id() @@ -2843,7 +2865,7 @@ async def _resolve_previous_response_chain(self, response_id: str) -> tuple[list if current_id in seen_response_ids: _raise_unknown_previous_response_id() seen_response_ids.add(current_id) - snapshot = await repos.response_snapshots.get(current_id) + snapshot = await repos.response_snapshots.get(current_id, api_key_id=api_key_id) if snapshot is None: _raise_unknown_previous_response_id() chain.append( @@ -2873,6 +2895,7 @@ async def _persist_response_snapshot( response_id: str | None, parent_response_id: str | None, account_id: str | None, + api_key_id: str | None, model: str, input_items: list[JsonValue], response_payload: dict[str, JsonValue] | None, @@ -2893,6 +2916,7 @@ async def _persist_response_snapshot( response_id=response_id, parent_response_id=parent_response_id, account_id=account_id, + api_key_id=api_key_id, model=model, input_items=input_items, response_payload=response_payload, @@ -3003,6 +3027,7 @@ class _WebSocketRequestState: current_input_items: list[JsonValue] = field(default_factory=list) output_items: dict[int, dict[str, JsonValue]] = field(default_factory=dict) awaiting_response_created: bool = False + api_key_id: str | None = None @dataclass(slots=True) @@ -3101,7 +3126,7 @@ def _replayable_response_output_items(response_payload: dict[str, JsonValue]) -> for item in output_value if not (isinstance(item, dict) and item.get("type") == "reasoning") ] - normalized_output = _sanitize_input_items(filtered_output) + normalized_output = sanitize_input_items(filtered_output) replay_items: list[JsonValue] = [] for item in normalized_output: if isinstance(item, dict) and item.get("role") == "assistant": diff --git a/openspec/changes/support-previous-response-id-persistence/specs/responses-api-compat/spec.md b/openspec/changes/support-previous-response-id-persistence/specs/responses-api-compat/spec.md index 07667a6f..47df49ab 100644 --- a/openspec/changes/support-previous-response-id-persistence/specs/responses-api-compat/spec.md +++ b/openspec/changes/support-previous-response-id-persistence/specs/responses-api-compat/spec.md @@ -1,7 +1,7 @@ ## MODIFIED Requirements ### Requirement: Support Responses input types and conversation constraints -The service MUST accept `input` as either a string or an array of input items. When `input` is a string, the service MUST normalize it into a single user input item with `input_text` content before forwarding upstream. When a client supplies `previous_response_id`, the service MUST resolve that id from proxy-managed durable response snapshots, rebuild the prior conversation input/output history as explicit upstream input items, and continue to reject requests that include both `conversation` and `previous_response_id`. +The service MUST accept `input` as either a string or an array of input items. When `input` is a string, the service MUST normalize it into a single user input item with `input_text` content before forwarding upstream. When a client supplies `previous_response_id`, the service MUST resolve that id from proxy-managed durable response snapshots scoped to the current requester, rebuild the prior conversation input/output history as explicit upstream input items, and continue to reject requests that include both `conversation` and `previous_response_id`. #### Scenario: previous response id resolves to replayable history - **WHEN** the client sends `previous_response_id` that matches a persisted prior response snapshot @@ -9,9 +9,14 @@ The service MUST accept `input` as either a string or an array of input items. W - **AND** the current request's `instructions` remain the only top-level instructions forwarded upstream #### Scenario: unknown previous response id -- **WHEN** the client sends `previous_response_id` that does not match a persisted prior response snapshot +- **WHEN** the client sends `previous_response_id` that does not match a persisted prior response snapshot for the current requester - **THEN** the service returns a 400 OpenAI-format error envelope with `param=previous_response_id` +#### Scenario: previous response id belongs to a different API key +- **WHEN** the client sends `previous_response_id` that matches a persisted prior response snapshot for a different API key +- **THEN** the service returns a 400 OpenAI-format error envelope with `param=previous_response_id` +- **AND** the message remains `Unknown previous_response_id` + #### Scenario: conversation and previous response id conflict - **WHEN** the client provides both `conversation` and `previous_response_id` - **THEN** the service returns a 4xx response with an OpenAI error envelope indicating invalid parameters diff --git a/openspec/specs/responses-api-compat/spec.md b/openspec/specs/responses-api-compat/spec.md index cc2a7a02..c1b8681e 100644 --- a/openspec/specs/responses-api-compat/spec.md +++ b/openspec/specs/responses-api-compat/spec.md @@ -16,7 +16,7 @@ The service MUST accept POST requests to `/v1/responses` with a JSON body and MU - **THEN** the service returns a 4xx response with an OpenAI error envelope describing the invalid parameter ### Requirement: Support Responses input types and conversation constraints -The service MUST accept `input` as either a string or an array of input items. When `input` is a string, the service MUST normalize it into a single user input item with `input_text` content before forwarding upstream. When the client supplies `previous_response_id`, the service MUST resolve that id from proxy-managed durable response snapshots, rebuild the prior conversation input/output history as explicit upstream input items, and continue to reject requests that include both `conversation` and `previous_response_id`. +The service MUST accept `input` as either a string or an array of input items. When `input` is a string, the service MUST normalize it into a single user input item with `input_text` content before forwarding upstream. When the client supplies `previous_response_id`, the service MUST resolve that id from proxy-managed durable response snapshots scoped to the current requester, rebuild the prior conversation input/output history as explicit upstream input items, and continue to reject requests that include both `conversation` and `previous_response_id`. #### Scenario: String input - **WHEN** the client sends `input` as a string @@ -27,10 +27,15 @@ The service MUST accept `input` as either a string or an array of input items. W - **THEN** the request is accepted and each item is forwarded in order #### Scenario: previous_response_id resolved from durable snapshots -- **WHEN** the client provides `previous_response_id` that matches a persisted prior response snapshot +- **WHEN** the client provides `previous_response_id` that matches a persisted prior response snapshot for the current requester - **THEN** the service forwards the rebuilt prior input/output history before the current request input - **AND** it does not carry forward prior `instructions` +#### Scenario: previous_response_id exists for another API key +- **WHEN** the client provides `previous_response_id` that matches a persisted prior response snapshot owned by a different API key +- **THEN** the service returns a 400 OpenAI invalid_request_error with `param` set to `previous_response_id` +- **AND** the error message remains `Unknown previous_response_id` + #### Scenario: conversation and previous_response_id conflict - **WHEN** the client provides both `conversation` and `previous_response_id` - **THEN** the service returns a 4xx response with an OpenAI error envelope indicating invalid parameters diff --git a/tests/integration/test_migrations.py b/tests/integration/test_migrations.py index 4d1e1fe6..1fb61873 100644 --- a/tests/integration/test_migrations.py +++ b/tests/integration/test_migrations.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import pytest import sqlalchemy as sa from sqlalchemy import text @@ -30,6 +31,7 @@ from app.db.models import Account, AccountStatus from app.db.session import SessionLocal from app.modules.accounts.repository import AccountsRepository +from app.modules.proxy.response_snapshots_repository import ResponseSnapshotsRepository try: from app.db.migrate import check_migration_policy @@ -78,6 +80,7 @@ def _inspect_tables(sync_session): "response_id", "parent_response_id", "account_id", + "api_key_id", "model", "input_items_json", "response_json", @@ -85,6 +88,39 @@ def _inspect_tables(sync_session): }.issubset(column_names) +@pytest.mark.asyncio +async def test_response_snapshots_repository_scopes_by_api_key_and_preserves_created_at(db_setup): + async with SessionLocal() as session: + repo = ResponseSnapshotsRepository(session) + first = await repo.upsert( + response_id="resp_scoped", + parent_response_id=None, + account_id="acc_scoped", + api_key_id="key_a", + model="gpt-5.2", + input_items=[{"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}], + response_payload={"id": "resp_scoped", "output": []}, + ) + first_created_at = first.created_at + + await asyncio.sleep(1.1) + + second = await repo.upsert( + response_id="resp_scoped", + parent_response_id=None, + account_id="acc_scoped", + api_key_id="key_a", + model="gpt-5.2", + input_items=[{"role": "user", "content": [{"type": "input_text", "text": "Hello again"}]}], + response_payload={"id": "resp_scoped", "output": []}, + ) + + assert second.created_at == first_created_at + assert await repo.get("resp_scoped", api_key_id="key_a") is not None + assert await repo.get("resp_scoped", api_key_id="key_b") is None + assert await repo.get("resp_scoped", api_key_id=None) is None + + @pytest.mark.asyncio async def test_run_startup_migrations_preserves_unknown_plan_types(db_setup): async with SessionLocal() as session: diff --git a/tests/integration/test_openai_compat_features.py b/tests/integration/test_openai_compat_features.py index a2c1d1bb..88210e08 100644 --- a/tests/integration/test_openai_compat_features.py +++ b/tests/integration/test_openai_compat_features.py @@ -154,6 +154,100 @@ async def fake_stream(payload, headers, access_token, account_id, base_url=None, ] +@pytest.mark.asyncio +async def test_v1_responses_previous_response_id_is_scoped_to_api_key(async_client, app_instance, monkeypatch): + await _import_account(async_client, "acc_prev_scoped", "prev-scoped@example.com") + + enable = await async_client.put( + "/api/settings", + json={ + "stickyThreadsEnabled": False, + "preferEarlierResetAccounts": False, + "totpRequiredOnLogin": False, + "apiKeyAuthEnabled": True, + }, + ) + assert enable.status_code == 200 + + created_a = await async_client.post( + "/api/api-keys/", + json={"name": "response-key-a"}, + ) + assert created_a.status_code == 200 + key_a = created_a.json()["key"] + + created_b = await async_client.post( + "/api/api-keys/", + json={"name": "response-key-b"}, + ) + assert created_b.status_code == 200 + key_b = created_b.json()["key"] + + seen_inputs: list[object] = [] + + async def fake_stream(payload, headers, access_token, account_id, base_url=None, raise_for_status=False, **_kwargs): + seen_inputs.append(payload.input) + if len(seen_inputs) == 1: + yield ( + 'data: {"type":"response.output_item.done","output_index":0,' + '"item":{"id":"msg_prev","type":"message","role":"assistant",' + '"content":[{"type":"output_text","text":"Prior answer"}]}}\n\n' + ) + yield _completed_event("resp_prev_scoped") + return + yield _completed_event("resp_followup_scoped") + + monkeypatch.setattr(proxy_module, "core_stream_responses", fake_stream) + + first = await async_client.post( + "/v1/responses", + headers={"Authorization": f"Bearer {key_a}"}, + json={"model": "gpt-5.2", "input": "Hello"}, + ) + assert first.status_code == 200 + + if hasattr(app_instance.state, "proxy_service"): + delattr(app_instance.state, "proxy_service") + + second = await async_client.post( + "/v1/responses", + headers={"Authorization": f"Bearer {key_a}"}, + json={ + "model": "gpt-5.2", + "previous_response_id": "resp_prev_scoped", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "Continue"}]}], + }, + ) + assert second.status_code == 200 + + if hasattr(app_instance.state, "proxy_service"): + delattr(app_instance.state, "proxy_service") + + blocked = await async_client.post( + "/v1/responses", + headers={"Authorization": f"Bearer {key_b}"}, + json={ + "model": "gpt-5.2", + "previous_response_id": "resp_prev_scoped", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "Continue"}]}], + }, + ) + assert blocked.status_code == 400 + error = blocked.json()["error"] + assert error["type"] == "invalid_request_error" + assert error["param"] == "previous_response_id" + assert error["message"] == "Unknown previous_response_id" + + assert seen_inputs == [ + [{"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}], + [ + {"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}, + {"role": "assistant", "content": [{"type": "output_text", "text": "Prior answer"}]}, + {"role": "user", "content": [{"type": "input_text", "text": "Continue"}]}, + ], + ] + + @pytest.mark.asyncio @pytest.mark.parametrize( "tool_payload", diff --git a/tests/integration/test_proxy_websocket_responses.py b/tests/integration/test_proxy_websocket_responses.py index d7e884ec..3d22bee9 100644 --- a/tests/integration/test_proxy_websocket_responses.py +++ b/tests/integration/test_proxy_websocket_responses.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import base64 import json from collections import deque from types import SimpleNamespace @@ -255,6 +256,168 @@ async def fake_write_request_log(self, **kwargs): ] + +def test_v1_responses_websocket_previous_response_id_is_scoped_to_api_key(app_instance, monkeypatch): + first_upstream = _FakeUpstreamWebSocket( + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_scoped", "object": "response", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.output_item.done", + "output_index": 0, + "item": { + "id": "msg_ws_scoped", + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Prior answer"}], + }, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": {"id": "resp_ws_scoped", "object": "response", "status": "completed"}, + }, + separators=(",", ":"), + ), + ), + ] + ) + + class _FakeSettingsCache: + async def get(self): + return _websocket_settings() + + async def allow_firewall(_websocket): + return None + + upstreams = deque([first_upstream]) + + async def fake_connect_proxy_websocket( + self, + headers, + sticky_key, + sticky_kind, + prefer_earlier_reset, + reallocate_sticky, + sticky_max_age_seconds, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + preferred_account_id=None, + ): + del self, headers, sticky_key, sticky_kind, prefer_earlier_reset + del reallocate_sticky, sticky_max_age_seconds, routing_strategy, model + del api_key, client_send_lock, websocket, preferred_account_id + return SimpleNamespace(id="acct_ws_scoped"), upstreams.popleft() + + async def fake_write_request_log(self, **kwargs): + del self, kwargs + + monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _FakeSettingsCache()) + monkeypatch.setattr(proxy_module.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + monkeypatch.setattr(proxy_module.ProxyService, "_write_request_log", fake_write_request_log) + + with TestClient(app_instance) as client: + enable = client.put( + "/api/settings", + json={ + "stickyThreadsEnabled": False, + "preferEarlierResetAccounts": False, + "totpRequiredOnLogin": False, + "apiKeyAuthEnabled": True, + }, + ) + assert enable.status_code == 200 + + imported = client.post( + "/api/accounts/import", + files={ + "auth_json": ( + "auth.json", + json.dumps( + { + "tokens": { + "idToken": "header." + + base64.urlsafe_b64encode( + json.dumps( + { + "email": "ws-scoped@example.com", + "chatgpt_account_id": "acc_ws_scoped", + "https://api.openai.com/auth": {"chatgpt_plan_type": "plus"}, + }, + separators=(",", ":"), + ).encode("utf-8") + ).rstrip(b"=").decode("ascii") + + ".sig", + "accessToken": "access-token", + "refreshToken": "refresh-token", + "accountId": "acc_ws_scoped", + } + }, + separators=(",", ":"), + ), + "application/json", + ) + }, + ) + assert imported.status_code == 200 + + created_a = client.post("/api/api-keys/", json={"name": "ws-key-a"}) + assert created_a.status_code == 200 + key_a = created_a.json()["key"] + + created_b = client.post("/api/api-keys/", json={"name": "ws-key-b"}) + assert created_b.status_code == 200 + key_b = created_b.json()["key"] + + with client.websocket_connect("/v1/responses", headers={"Authorization": f"Bearer {key_a}"}) as websocket: + websocket.send_text(json.dumps({"type": "response.create", "model": "gpt-5.2", "input": "Hello"})) + assert json.loads(websocket.receive_text())["type"] == "response.created" + assert json.loads(websocket.receive_text())["type"] == "response.output_item.done" + assert json.loads(websocket.receive_text())["type"] == "response.completed" + + if hasattr(app_instance.state, "proxy_service"): + delattr(app_instance.state, "proxy_service") + + with client.websocket_connect("/v1/responses", headers={"Authorization": f"Bearer {key_b}"}) as websocket: + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.2", + "previous_response_id": "resp_ws_scoped", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "Continue"}]}], + } + ) + ) + event = json.loads(websocket.receive_text()) + + assert event["type"] == "error" + assert event["status"] == 400 + assert event["error"]["type"] == "invalid_request_error" + assert event["error"]["param"] == "previous_response_id" + assert event["error"]["message"] == "Unknown previous_response_id" + def test_backend_responses_websocket_proxies_upstream_and_persists_log(app_instance, monkeypatch): upstream_messages = [ _FakeUpstreamMessage( diff --git a/tests/unit/test_proxy_utils.py b/tests/unit/test_proxy_utils.py index 8a0761a4..f34a003e 100644 --- a/tests/unit/test_proxy_utils.py +++ b/tests/unit/test_proxy_utils.py @@ -302,8 +302,14 @@ def __init__(self, snapshots: dict[str, object] | None = None) -> None: self._snapshots = snapshots or {} self.upserts: list[dict[str, object]] = [] - async def get(self, response_id: str): - return self._snapshots.get(response_id) + async def get(self, response_id: str, *, api_key_id: str | None): + snapshot = self._snapshots.get(response_id) + if snapshot is None: + return None + snapshot_api_key_id = getattr(snapshot, "api_key_id", None) + if snapshot_api_key_id != api_key_id: + return None + return snapshot async def upsert(self, **kwargs: object) -> object: self.upserts.append(dict(kwargs)) @@ -311,6 +317,7 @@ async def upsert(self, **kwargs: object) -> object: response_id=kwargs["response_id"], parent_response_id=kwargs.get("parent_response_id"), account_id=kwargs.get("account_id"), + api_key_id=kwargs.get("api_key_id"), model=kwargs["model"], input_items_json=json.dumps(kwargs["input_items"], separators=(",", ":")), response_json=json.dumps(kwargs["response_payload"], separators=(",", ":")), @@ -2886,6 +2893,81 @@ async def test_prepare_websocket_response_create_request_rebuilds_previous_respo ] +@pytest.mark.asyncio +async def test_prepare_websocket_response_create_request_rejects_previous_response_from_other_api_key(monkeypatch): + request_logs = _RequestLogsRecorder() + response_snapshots = _ResponseSnapshotsStub( + { + "resp_prev_other_key": SimpleNamespace( + response_id="resp_prev_other_key", + parent_response_id=None, + account_id="acc_prev", + api_key_id="key_a", + input_items_json=json.dumps( + [{"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}], + separators=(",", ":"), + ), + response_json=json.dumps( + { + "id": "resp_prev_other_key", + "output": [ + { + "id": "msg_prev", + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Prior answer"}], + } + ], + }, + separators=(",", ":"), + ), + ) + } + ) + service = proxy_service.ProxyService(_repo_factory(request_logs, response_snapshots=response_snapshots)) + + monkeypatch.setattr(service, "_reserve_websocket_api_key_usage", AsyncMock(return_value=None)) + monkeypatch.setattr( + service, + "_refresh_websocket_api_key_policy", + AsyncMock( + return_value=ApiKeyData( + id="key_b", + name="key-b", + key_prefix="sk-test", + allowed_models=None, + enforced_model=None, + enforced_reasoning_effort=None, + expires_at=None, + is_active=True, + created_at=utcnow(), + last_used_at=None, + ) + ), + ) + + with pytest.raises(proxy_module.ProxyResponseError) as exc_info: + await service._prepare_websocket_response_create_request( + { + "type": "response.create", + "model": "gpt-5.2", + "previous_response_id": "resp_prev_other_key", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "Continue"}]}], + }, + headers={}, + codex_session_affinity=False, + openai_cache_affinity=True, + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=300, + api_key=None, + ) + + error = _assert_proxy_response_error(exc_info.value) + assert error.status_code == 400 + assert error.payload["error"]["param"] == "previous_response_id" + assert error.payload["error"]["message"] == "Unknown previous_response_id" + + @pytest.mark.asyncio async def test_prepare_websocket_response_create_request_normalizes_payload_and_reserves_forwarded_tier(monkeypatch): request_logs = _RequestLogsRecorder() From 948b486e1b18c1bb42781a90b3bb53a52cdb083e Mon Sep 17 00:00:00 2001 From: xirothedev Date: Mon, 16 Mar 2026 09:47:58 +0700 Subject: [PATCH 5/7] fix(proxy): retry websocket stream setup once --- app/modules/proxy/service.py | 186 ++++++++++++++++-- .../test_proxy_websocket_responses.py | 132 +++++++++++++ 2 files changed, 307 insertions(+), 11 deletions(-) diff --git a/app/modules/proxy/service.py b/app/modules/proxy/service.py index 406532a5..f5c43870 100644 --- a/app/modules/proxy/service.py +++ b/app/modules/proxy/service.py @@ -1,12 +1,13 @@ from __future__ import annotations import asyncio +import contextlib import inspect import json import logging import time from collections import deque -from collections.abc import Sequence +from collections.abc import Collection, Sequence from dataclasses import dataclass, field from hashlib import sha256 from typing import AsyncIterator, Mapping, NoReturn @@ -764,6 +765,9 @@ async def proxy_responses_websocket( response_create_gate=response_create_gate, proxy_request_budget_seconds=runtime_settings.proxy_request_budget_seconds, stream_idle_timeout_seconds=runtime_settings.stream_idle_timeout_seconds, + filtered_headers=filtered_headers, + prefer_earlier_reset=prefer_earlier_reset, + routing_strategy=routing_strategy, ) ) @@ -853,9 +857,18 @@ async def _prepare_websocket_response_create_request( request_model=responses_payload.model, request_service_tier=forwarded_service_tier, ) + affinity_policy = _sticky_key_for_responses_request( + responses_payload, + headers, + codex_session_affinity=codex_session_affinity, + openai_cache_affinity=openai_cache_affinity, + openai_cache_affinity_max_age_seconds=openai_cache_affinity_max_age_seconds, + sticky_threads_enabled=sticky_threads_enabled, + ) + text_data = json.dumps(upstream_payload, ensure_ascii=True, separators=(",", ":")) return _PreparedWebSocketRequest( - text_data=json.dumps(upstream_payload, ensure_ascii=True, separators=(",", ":")), + text_data=text_data, request_state=_WebSocketRequestState( request_id=f"ws_{uuid4().hex}", model=responses_payload.model, @@ -868,15 +881,13 @@ async def _prepare_websocket_response_create_request( current_input_items=resolved_request.current_input_items, awaiting_response_created=True, api_key_id=refreshed_api_key.id if refreshed_api_key else None, + upstream_text_data=text_data, + affinity_key=affinity_policy.key, + affinity_kind=affinity_policy.kind, + affinity_reallocate_sticky=affinity_policy.reallocate_sticky, + affinity_max_age_seconds=affinity_policy.max_age_seconds, ), - affinity_policy=_sticky_key_for_responses_request( - responses_payload, - headers, - codex_session_affinity=codex_session_affinity, - openai_cache_affinity=openai_cache_affinity, - openai_cache_affinity_max_age_seconds=openai_cache_affinity_max_age_seconds, - sticky_threads_enabled=sticky_threads_enabled, - ), + affinity_policy=affinity_policy, ) async def _connect_proxy_websocket( @@ -1175,6 +1186,9 @@ async def _relay_upstream_websocket_messages( response_create_gate: asyncio.Semaphore, proxy_request_budget_seconds: float, stream_idle_timeout_seconds: float, + filtered_headers: dict[str, str], + prefer_earlier_reset: bool, + routing_strategy: RoutingStrategy, ) -> None: try: while True: @@ -1255,12 +1269,36 @@ async def _relay_upstream_websocket_messages( async with client_send_lock: await websocket.send_bytes(message.data) continue + disconnect_message = _upstream_websocket_disconnect_message(message) + logger.warning( + "Upstream websocket disconnected before terminal event account_id=%s request_id=%s kind=%s close_code=%s error=%s", + account_id_value, + get_request_id(), + message.kind, + message.close_code, + message.error, + ) + await self._handle_stream_error(account, {"message": disconnect_message}, "stream_incomplete") + retried = await self._retry_websocket_request_after_disconnect( + disconnected_account=account, + pending_requests=pending_requests, + pending_lock=pending_lock, + filtered_headers=filtered_headers, + prefer_earlier_reset=prefer_earlier_reset, + routing_strategy=routing_strategy, + proxy_request_budget_seconds=proxy_request_budget_seconds, + disconnect_event=message, + ) + if retried is not None: + account, upstream = retried + account_id_value = account.id + continue await self._fail_pending_websocket_requests( account_id_value=account_id_value, pending_requests=pending_requests, pending_lock=pending_lock, error_code="stream_incomplete", - error_message=_upstream_websocket_disconnect_message(message), + error_message=disconnect_message, api_key=api_key, websocket=websocket, client_send_lock=client_send_lock, @@ -1277,6 +1315,115 @@ async def _relay_upstream_websocket_messages( except Exception: logger.debug("Failed to close downstream websocket", exc_info=True) + async def _retry_websocket_request_after_disconnect( + self, + *, + disconnected_account: Account, + pending_requests: deque[_WebSocketRequestState], + pending_lock: anyio.Lock, + filtered_headers: dict[str, str], + prefer_earlier_reset: bool, + routing_strategy: RoutingStrategy, + proxy_request_budget_seconds: float, + disconnect_event: UpstreamWebSocketMessage, + ) -> tuple[Account, UpstreamResponsesWebSocket] | None: + async with pending_lock: + retry_candidates = [ + request_state for request_state in pending_requests if _is_retryable_websocket_request_state(request_state) + ] + if len(retry_candidates) != 1 or len(pending_requests) != 1: + return None + request_state = retry_candidates[0] + request_state.websocket_retry_count += 1 + + logger.warning( + "Retrying websocket request after upstream disconnect request_id=%s failed_account_id=%s kind=%s close_code=%s error=%s retry_count=%s", + request_state.request_id, + disconnected_account.id, + disconnect_event.kind, + disconnect_event.close_code, + disconnect_event.error, + request_state.websocket_retry_count, + ) + + deadline = request_state.started_at + proxy_request_budget_seconds + try: + selection = await self._select_account_with_budget( + deadline, + request_id=request_state.request_id, + kind="websocket", + sticky_key=request_state.affinity_key, + sticky_kind=request_state.affinity_kind, + reallocate_sticky=request_state.affinity_reallocate_sticky, + sticky_max_age_seconds=request_state.affinity_max_age_seconds, + prefer_earlier_reset_accounts=prefer_earlier_reset, + routing_strategy=routing_strategy, + model=request_state.model, + preferred_account_id=request_state.preferred_account_id, + exclude_account_ids={disconnected_account.id}, + ) + except Exception: + logger.warning( + "Failed to select retry account after websocket disconnect request_id=%s failed_account_id=%s", + request_state.request_id, + disconnected_account.id, + exc_info=True, + ) + return None + + retry_account = selection.account + if retry_account is None: + logger.warning( + "No retry account available after websocket disconnect request_id=%s failed_account_id=%s error=%s", + request_state.request_id, + disconnected_account.id, + selection.error_message, + ) + return None + + retry_upstream: UpstreamResponsesWebSocket | None = None + try: + remaining_budget = _remaining_budget_seconds(deadline) + if remaining_budget <= 0: + return None + retry_account = await self._ensure_fresh_with_budget(retry_account, timeout_seconds=remaining_budget) + remaining_budget = _remaining_budget_seconds(deadline) + if remaining_budget <= 0: + return None + retry_upstream = await self._open_upstream_websocket_with_budget( + retry_account, + filtered_headers, + timeout_seconds=remaining_budget, + ) + if request_state.upstream_text_data is not None: + await retry_upstream.send_text(request_state.upstream_text_data) + elif request_state.upstream_bytes_data is not None: + await retry_upstream.send_bytes(request_state.upstream_bytes_data) + else: + await retry_upstream.close() + return None + except Exception: + logger.warning( + "Websocket retry failed request_id=%s failed_account_id=%s retry_account_id=%s", + request_state.request_id, + disconnected_account.id, + retry_account.id, + exc_info=True, + ) + await self._load_balancer.record_error(retry_account) + if retry_upstream is not None: + with contextlib.suppress(Exception): + await retry_upstream.close() + return None + + logger.info( + "Retried websocket request request_id=%s failed_account_id=%s retry_account_id=%s", + request_state.request_id, + disconnected_account.id, + retry_account.id, + ) + return retry_account, retry_upstream + async def _process_upstream_websocket_text( self, text: str, @@ -2791,6 +2938,7 @@ async def _select_account_with_budget( model: str | None = None, additional_limit_name: str | None = None, preferred_account_id: str | None = None, + exclude_account_ids: Collection[str] | None = None, ) -> AccountSelection: remaining_budget = _remaining_budget_seconds(deadline) if remaining_budget <= 0: @@ -2810,6 +2958,7 @@ async def _select_account_with_budget( model=model, additional_limit_name=additional_limit_name, preferred_account_id=preferred_account_id, + exclude_account_ids=exclude_account_ids, ) except TimeoutError: logger.warning("%s account selection exceeded request budget request_id=%s", kind.title(), request_id) @@ -3028,6 +3177,13 @@ class _WebSocketRequestState: output_items: dict[int, dict[str, JsonValue]] = field(default_factory=dict) awaiting_response_created: bool = False api_key_id: str | None = None + upstream_text_data: str | None = None + upstream_bytes_data: bytes | None = None + affinity_key: str | None = None + affinity_kind: StickySessionKind | None = None + affinity_reallocate_sticky: bool = False + affinity_max_age_seconds: int | None = None + websocket_retry_count: int = 0 @dataclass(slots=True) @@ -3235,6 +3391,14 @@ def _pop_terminal_websocket_request_state( return None +def _is_retryable_websocket_request_state(request_state: _WebSocketRequestState) -> bool: + if request_state.response_id is not None: + return False + if request_state.websocket_retry_count >= 1: + return False + return request_state.upstream_text_data is not None or request_state.upstream_bytes_data is not None + + def _upstream_websocket_disconnect_message(message: UpstreamWebSocketMessage) -> str: if message.kind == "error" and message.error: return f"Upstream websocket closed before response.completed: {message.error}" diff --git a/tests/integration/test_proxy_websocket_responses.py b/tests/integration/test_proxy_websocket_responses.py index 3d22bee9..fc7434aa 100644 --- a/tests/integration/test_proxy_websocket_responses.py +++ b/tests/integration/test_proxy_websocket_responses.py @@ -1578,6 +1578,131 @@ async def fake_write_request_log(self, **kwargs): assert log_calls[1]["input_tokens"] == 3 +def test_backend_responses_websocket_retries_once_after_upstream_eof_before_response_created(app_instance, monkeypatch): + first_upstream = _FakeUpstreamWebSocket([_FakeUpstreamMessage("close", close_code=1011)]) + second_upstream = _FakeUpstreamWebSocket( + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + {"type": "response.created", "response": {"id": "resp_ws_retry", "status": "in_progress"}}, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + {"type": "response.completed", "response": {"id": "resp_ws_retry", "status": "completed"}}, + separators=(",", ":"), + ), + ), + ] + ) + connect_accounts: list[str] = [] + stream_errors: list[tuple[str, str]] = [] + log_calls: list[dict[str, object]] = [] + + class _FakeSettingsCache: + async def get(self): + return _websocket_settings() + + async def allow_firewall(_websocket): + return None + + async def allow_proxy_api_key(_authorization: str | None): + return None + + async def fake_select_account_with_budget( + self, + deadline, + *, + request_id, + kind, + sticky_key=None, + sticky_kind=None, + reallocate_sticky=False, + sticky_max_age_seconds=None, + prefer_earlier_reset_accounts=False, + routing_strategy="usage_weighted", + model=None, + additional_limit_name=None, + preferred_account_id=None, + exclude_account_ids=None, + ): + del ( + self, + deadline, + request_id, + kind, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset_accounts, + routing_strategy, + model, + additional_limit_name, + preferred_account_id, + ) + if exclude_account_ids: + assert exclude_account_ids == {"acct_ws_first"} + return SimpleNamespace(account=SimpleNamespace(id="acct_ws_retry"), error_message=None, error_code=None) + return SimpleNamespace(account=SimpleNamespace(id="acct_ws_first"), error_message=None, error_code=None) + + async def fake_ensure_fresh_with_budget(self, account, *, force=False, timeout_seconds=None): + del self, force, timeout_seconds + return account + + async def fake_open_upstream_websocket_with_budget(self, account, headers, *, timeout_seconds): + del self, headers, timeout_seconds + connect_accounts.append(account.id) + if account.id == "acct_ws_first": + return first_upstream + return second_upstream + + async def fake_write_request_log(self, **kwargs): + del self + log_calls.append(kwargs) + + async def fake_handle_stream_error(self, account, error, code): + del error + stream_errors.append((account.id, code)) + + monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall) + monkeypatch.setattr(proxy_api_module, "validate_proxy_api_key_authorization", allow_proxy_api_key) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _FakeSettingsCache()) + monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) + monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) + monkeypatch.setattr(proxy_module.ProxyService, "_open_upstream_websocket_with_budget", fake_open_upstream_websocket_with_budget) + monkeypatch.setattr(proxy_module.ProxyService, "_write_request_log", fake_write_request_log) + monkeypatch.setattr(proxy_module.ProxyService, "_handle_stream_error", fake_handle_stream_error) + + request_payload = { + "type": "response.create", + "model": "gpt-5.4", + "instructions": "", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "hi"}]}], + "stream": True, + } + + with TestClient(app_instance) as client: + with client.websocket_connect("/backend-api/codex/responses") as websocket: + websocket.send_text(json.dumps(request_payload)) + created_event = json.loads(websocket.receive_text()) + completed_event = json.loads(websocket.receive_text()) + + assert [created_event["type"], completed_event["type"]] == ["response.created", "response.completed"] + assert connect_accounts == ["acct_ws_first", "acct_ws_retry"] + assert stream_errors == [("acct_ws_first", "stream_incomplete")] + assert len(first_upstream.sent_text) == 1 + assert len(second_upstream.sent_text) == 1 + assert json.loads(first_upstream.sent_text[0])["type"] == "response.create" + assert json.loads(second_upstream.sent_text[0])["type"] == "response.create" + assert len(log_calls) == 1 + assert log_calls[0]["request_id"] == "resp_ws_retry" + assert log_calls[0]["status"] == "success" + + def test_backend_responses_websocket_emits_response_failed_before_close_on_upstream_eof(app_instance, monkeypatch): upstream_messages = [ _FakeUpstreamMessage( @@ -1591,6 +1716,7 @@ def test_backend_responses_websocket_emits_response_failed_before_close_on_upstr ] fake_upstream = _FakeUpstreamWebSocket(upstream_messages) log_calls: list[dict[str, object]] = [] + stream_errors: list[tuple[str, str]] = [] class _FakeSettingsCache: async def get(self): @@ -1639,11 +1765,16 @@ async def fake_write_request_log(self, **kwargs): del self log_calls.append(kwargs) + async def fake_handle_stream_error(self, account, error, code): + del self, error + stream_errors.append((account.id, code)) + monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall) monkeypatch.setattr(proxy_api_module, "validate_proxy_api_key_authorization", allow_proxy_api_key) monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _FakeSettingsCache()) monkeypatch.setattr(proxy_module.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) monkeypatch.setattr(proxy_module.ProxyService, "_write_request_log", fake_write_request_log) + monkeypatch.setattr(proxy_module.ProxyService, "_handle_stream_error", fake_handle_stream_error) request_payload = { "type": "response.create", @@ -1668,3 +1799,4 @@ async def fake_write_request_log(self, **kwargs): assert log_calls[0]["request_id"] == "resp_ws_eof" assert log_calls[0]["status"] == "error" assert log_calls[0]["error_code"] == "stream_incomplete" + assert stream_errors == [("acct_ws_proxy", "stream_incomplete")] From a80159f1d9a5d535307eabd435d1625a6600531b Mon Sep 17 00:00:00 2001 From: xirothedev Date: Mon, 16 Mar 2026 09:58:19 +0700 Subject: [PATCH 6/7] ci: add branch image workflow for vps deploy --- .github/workflows/branch-image.yml | 93 ++++++++++++++++++++++++++++++ docker-compose.vps.yml | 15 +++++ 2 files changed, 108 insertions(+) create mode 100644 .github/workflows/branch-image.yml create mode 100644 docker-compose.vps.yml diff --git a/.github/workflows/branch-image.yml b/.github/workflows/branch-image.yml new file mode 100644 index 00000000..34999552 --- /dev/null +++ b/.github/workflows/branch-image.yml @@ -0,0 +1,93 @@ +name: Branch Docker Image + +on: + workflow_dispatch: + inputs: + ref: + description: "Git ref to build (branch, tag, or SHA)" + required: true + image_tag: + description: "Optional image tag override" + required: false + platforms: + description: "Comma-separated target platforms" + required: false + default: "linux/amd64" + +concurrency: + group: ${{ github.workflow }}-${{ inputs.ref }} + cancel-in-progress: true + +jobs: + docker-image: + name: Build and push branch image + runs-on: ubuntu-24.04 + + permissions: + contents: read + packages: write + + outputs: + image: ${{ steps.image.outputs.image }} + tag: ${{ steps.image.outputs.tag }} + sha_tag: ${{ steps.image.outputs.sha_tag }} + + steps: + - name: Checkout repository + uses: actions/checkout@v6.0.2 + with: + ref: ${{ inputs.ref }} + + - name: Compute image tag + id: image + shell: bash + env: + INPUT_IMAGE_TAG: ${{ inputs.image_tag }} + run: | + set -euo pipefail + + short_sha="$(git rev-parse --short=12 HEAD)" + ref_name="$(git rev-parse --abbrev-ref HEAD)" + if [ "${ref_name}" = "HEAD" ]; then + ref_name="${short_sha}" + fi + + slug="$(printf '%s' "${ref_name}" | tr '[:upper:]' '[:lower:]' | sed -E 's#[^a-z0-9._-]+#-#g; s#-+#-#g; s#(^[-.]+|[-.]+$)##g')" + + tag="${INPUT_IMAGE_TAG:-${slug}-${short_sha}}" + image="ghcr.io/${GITHUB_REPOSITORY}:${tag}" + + echo "tag=${tag}" >> "${GITHUB_OUTPUT}" + echo "sha_tag=sha-${short_sha}" >> "${GITHUB_OUTPUT}" + echo "image=${image}" >> "${GITHUB_OUTPUT}" + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Log in to GHCR + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build and push Docker image + uses: docker/build-push-action@v6 + with: + context: . + file: Dockerfile + push: true + platforms: ${{ inputs.platforms }} + tags: | + ${{ steps.image.outputs.image }} + ghcr.io/${{ github.repository }}:${{ steps.image.outputs.sha_tag }} + cache-from: type=gha + cache-to: type=gha,mode=max + + - name: Summarize image + run: | + echo "Built image: ${{ steps.image.outputs.image }}" >> "${GITHUB_STEP_SUMMARY}" + echo "Extra tag: ghcr.io/${{ github.repository }}:${{ steps.image.outputs.sha_tag }}" >> "${GITHUB_STEP_SUMMARY}" diff --git a/docker-compose.vps.yml b/docker-compose.vps.yml new file mode 100644 index 00000000..b971ec1c --- /dev/null +++ b/docker-compose.vps.yml @@ -0,0 +1,15 @@ +services: + server: + image: ${CODEX_LB_IMAGE:-ghcr.io/soju06/codex-lb:latest} + env_file: + - .env.local + ports: + - "2455:2455" + - "1455:1455" + volumes: + - codex-lb-data:/var/lib/codex-lb + restart: unless-stopped + +volumes: + codex-lb-data: + name: codex-lb-data From 995b4d510f4b54d315a6cfe68811be3cdcb2f4fa Mon Sep 17 00:00:00 2001 From: xirothedev Date: Mon, 16 Mar 2026 11:09:00 +0700 Subject: [PATCH 7/7] docs(openspec): avoid sourcing env files in ops probe --- openspec/specs/responses-api-compat/ops.md | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/openspec/specs/responses-api-compat/ops.md b/openspec/specs/responses-api-compat/ops.md index c49d58e7..d8b78e68 100644 --- a/openspec/specs/responses-api-compat/ops.md +++ b/openspec/specs/responses-api-compat/ops.md @@ -49,8 +49,9 @@ This probe bypasses `codex-lb` selection and measures what the upstream returns Run: ```bash -set -a && source /home/egor/services/codex-lb-defin85/.env.local && set +a && cd /home/egor/services/codex-lb-defin85 && .venv/bin/python - <<'PY' -import asyncio, json +cd /home/egor/services/codex-lb-defin85 && .venv/bin/python - <<'PY' +import asyncio, json, os +from dotenv import dotenv_values from sqlalchemy import select from app.core.clients.http import close_http_client, init_http_client @@ -60,6 +61,7 @@ from app.db.models import Account from app.db.session import SessionLocal EMAIL = "TARGET_EMAIL" +os.environ.update({key: value for key, value in dotenv_values(".env.local").items() if value is not None}) async def main(): await init_http_client() @@ -122,6 +124,9 @@ asyncio.run(main()) PY ``` +Do not `source` `.env.local` directly in shell when URLs may include characters like `&`. +Load it with `python-dotenv` or quote those values first. + Expected useful output: ```json