Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions app/core/openai/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ def _sanitize_input_items(input_items: list[JsonValue]) -> list[JsonValue]:
return sanitized_input


def sanitize_input_items(input_items: list[JsonValue]) -> list[JsonValue]:
    """Public entry point for input-item sanitization.

    Delegates to the module-private ``_sanitize_input_items`` so external
    callers can sanitize Responses-API input items without importing a
    private name.
    """
    sanitized = _sanitize_input_items(input_items)
    return sanitized


def _sanitize_interleaved_reasoning_input_item(item: JsonValue) -> JsonValue | None:
item_mapping = _json_mapping_or_none(item)
if item_mapping is None:
Expand Down
75 changes: 75 additions & 0 deletions app/db/alembic/versions/20260327_000000_add_response_snapshots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""add durable response snapshots

Revision ID: 20260327_000000_add_response_snapshots
Revises: 20260321_210000_merge_request_log_tiers_and_dashboard_index_heads
Create Date: 2026-03-27 00:00:00.000000
"""

from __future__ import annotations

import sqlalchemy as sa
from alembic import op
from sqlalchemy.engine import Connection

# revision identifiers, used by Alembic.
revision = "20260327_000000_add_response_snapshots"  # this migration's id
down_revision = "20260321_210000_merge_request_log_tiers_and_dashboard_index_heads"  # parent revision
branch_labels = None  # no named branch for this revision
depends_on = None  # no cross-branch dependencies


def _table_exists(connection: Connection, table_name: str) -> bool:
    """Return True when *table_name* already exists on *connection*."""
    return sa.inspect(connection).has_table(table_name)


def _columns(connection: Connection, table_name: str) -> set[str]:
    """Return the column names of *table_name*, or an empty set if the table is absent."""
    inspector = sa.inspect(connection)
    if not inspector.has_table(table_name):
        return set()
    names: set[str] = set()
    for column in inspector.get_columns(table_name):
        name = column.get("name")
        if name is not None:
            names.add(str(name))
    return names


def _indexes(connection: Connection, table_name: str) -> set[str]:
    """Return the index names on *table_name*, or an empty set if the table is absent."""
    inspector = sa.inspect(connection)
    if not inspector.has_table(table_name):
        return set()
    names: set[str] = set()
    for index in inspector.get_indexes(table_name):
        name = index.get("name")
        if name is not None:
            names.add(str(name))
    return names


def upgrade() -> None:
    """Apply the migration: snapshot table, backfill column, and lookup index.

    Every step is guarded by an inspector check so re-running against a
    partially migrated database only performs the missing pieces.
    """
    connection = op.get_bind()

    if not _table_exists(connection, "response_snapshots"):
        op.create_table(
            "response_snapshots",
            sa.Column("response_id", sa.String(), nullable=False),
            sa.Column("parent_response_id", sa.String(), nullable=True),
            sa.Column("account_id", sa.String(), nullable=True),
            sa.Column("api_key_id", sa.String(), nullable=True),
            sa.Column("model", sa.String(), nullable=False),
            sa.Column("input_items_json", sa.Text(), nullable=False),
            sa.Column("response_json", sa.Text(), nullable=False),
            sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()),
            sa.PrimaryKeyConstraint("response_id"),
        )

    # An older deployment may already have the table without api_key_id.
    if "api_key_id" not in _columns(connection, "response_snapshots"):
        op.add_column("response_snapshots", sa.Column("api_key_id", sa.String(), nullable=True))

    if "idx_response_snapshots_parent_created_at" not in _indexes(connection, "response_snapshots"):
        op.create_index(
            "idx_response_snapshots_parent_created_at",
            "response_snapshots",
            ["parent_response_id", "created_at"],
            unique=False,
        )


def downgrade() -> None:
    """Revert the migration: drop the index (if present) and the table."""
    connection = op.get_bind()
    if not _table_exists(connection, "response_snapshots"):
        # Nothing to revert.
        return
    if "idx_response_snapshots_parent_created_at" in _indexes(connection, "response_snapshots"):
        op.drop_index("idx_response_snapshots_parent_created_at", table_name="response_snapshots")
    op.drop_table("response_snapshots")
16 changes: 16 additions & 0 deletions app/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,22 @@ class RequestLog(Base):
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)


class ResponseSnapshot(Base):
    """Durable snapshot of a proxied response, keyed by the upstream response id."""

    __tablename__ = "response_snapshots"
    __table_args__ = (
        # Composite index for looking up child snapshots by parent, ordered by time.
        Index("idx_response_snapshots_parent_created_at", "parent_response_id", "created_at"),
    )

    # Upstream response identifier; one snapshot per response.
    response_id: Mapped[str] = mapped_column(String, primary_key=True)
    # Preceding response in the chain, if any (None for chain roots).
    parent_response_id: Mapped[str | None] = mapped_column(String, nullable=True)
    # Originating account id, when known.
    account_id: Mapped[str | None] = mapped_column(String, nullable=True)
    # API key the snapshot belongs to; reads are scoped by this value.
    api_key_id: Mapped[str | None] = mapped_column(String, nullable=True)
    # Model name the response was generated with.
    model: Mapped[str] = mapped_column(String, nullable=False)
    # JSON-encoded list of sanitized input items.
    input_items_json: Mapped[str] = mapped_column(Text, nullable=False)
    # JSON-encoded response payload.
    response_json: Mapped[str] = mapped_column(Text, nullable=False)
    # Row creation timestamp, assigned server-side by the database.
    created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), nullable=False)


class StickySession(Base):
__tablename__ = "sticky_sessions"

Expand Down
2 changes: 2 additions & 0 deletions app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from app.modules.firewall.service import FirewallService
from app.modules.oauth.service import OauthService
from app.modules.proxy.repo_bundle import ProxyRepositories
from app.modules.proxy.response_snapshots_repository import ResponseSnapshotsRepository
from app.modules.proxy.service import ProxyService
from app.modules.proxy.sticky_repository import StickySessionsRepository
from app.modules.request_logs.repository import RequestLogsRepository
Expand Down Expand Up @@ -153,6 +154,7 @@ async def _proxy_repo_context() -> AsyncIterator[ProxyRepositories]:
sticky_sessions=StickySessionsRepository(session),
api_keys=ApiKeysRepository(session),
additional_usage=AdditionalUsageRepository(session),
response_snapshots=ResponseSnapshotsRepository(session),
)


Expand Down
14 changes: 14 additions & 0 deletions app/modules/proxy/load_balancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ async def select_account(
routing_strategy: RoutingStrategy = "usage_weighted",
model: str | None = None,
additional_limit_name: str | None = None,
preferred_account_id: str | None = None,
exclude_account_ids: Collection[str] | None = None,
) -> AccountSelection:
selection_inputs = await self._load_selection_inputs(
Expand Down Expand Up @@ -134,6 +135,7 @@ async def select_account(
sticky_max_age_seconds=sticky_max_age_seconds,
prefer_earlier_reset_accounts=prefer_earlier_reset_accounts,
routing_strategy=routing_strategy,
preferred_account_id=preferred_account_id,
sticky_repo=repos.sticky_sessions,
)
if result.account is not None:
Expand Down Expand Up @@ -347,8 +349,20 @@ async def _select_with_stickiness(
sticky_max_age_seconds: int | None,
prefer_earlier_reset_accounts: bool,
routing_strategy: RoutingStrategy,
preferred_account_id: str | None,
sticky_repo: StickySessionsRepository | None,
) -> SelectionResult:
if preferred_account_id:
preferred_state = next((state for state in states if state.account_id == preferred_account_id), None)
if preferred_state is not None:
preferred_result = select_account(
[preferred_state],
prefer_earlier_reset=prefer_earlier_reset_accounts,
routing_strategy=routing_strategy,
allow_backoff_fallback=False,
)
if preferred_result.account is not None:
return preferred_result
if not sticky_key or not sticky_repo:
return select_account(
states,
Expand Down
2 changes: 2 additions & 0 deletions app/modules/proxy/repo_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from app.modules.accounts.repository import AccountsRepository
from app.modules.api_keys.repository import ApiKeysRepository
from app.modules.proxy.response_snapshots_repository import ResponseSnapshotsRepository
from app.modules.proxy.sticky_repository import StickySessionsRepository
from app.modules.request_logs.repository import RequestLogsRepository
from app.modules.usage.repository import AdditionalUsageRepository, UsageRepository
Expand All @@ -19,6 +20,7 @@ class ProxyRepositories:
sticky_sessions: StickySessionsRepository
api_keys: ApiKeysRepository
additional_usage: AdditionalUsageRepository
response_snapshots: ResponseSnapshotsRepository | None = None


# Factory producing an async context manager that yields a ProxyRepositories bundle.
ProxyRepoFactory = Callable[[], AsyncContextManager[ProxyRepositories]]
95 changes: 95 additions & 0 deletions app/modules/proxy/response_snapshots_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from __future__ import annotations

import json

from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import Insert

from app.core.types import JsonValue
from app.db.models import ResponseSnapshot


class ResponseSnapshotsRepository:
    """Persistence layer for ``ResponseSnapshot`` rows.

    Reads are scoped by the owning API key; writes go through a
    dialect-aware upsert (PostgreSQL or SQLite ``ON CONFLICT``).
    """

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def get(self, response_id: str, *, api_key_id: str | None) -> ResponseSnapshot | None:
        """Return the snapshot for *response_id* owned by *api_key_id*, or None.

        A NULL ``api_key_id`` column only matches an explicit ``None`` argument,
        so keys cannot read each other's snapshots.
        """
        if not response_id:
            return None
        key_filter = (
            ResponseSnapshot.api_key_id.is_(None)
            if api_key_id is None
            else ResponseSnapshot.api_key_id == api_key_id
        )
        query = (
            select(ResponseSnapshot)
            .where(ResponseSnapshot.response_id == response_id)
            .where(key_filter)
        )
        outcome = await self._session.execute(query)
        return outcome.scalar_one_or_none()

    async def upsert(
        self,
        *,
        response_id: str,
        parent_response_id: str | None,
        account_id: str | None,
        api_key_id: str | None,
        model: str,
        input_items: list[JsonValue],
        response_payload: dict[str, JsonValue],
    ) -> ResponseSnapshot:
        """Insert or replace the snapshot for *response_id* and return the stored row.

        Raises:
            RuntimeError: if the row cannot be read back after the upsert.
        """
        # Compact JSON keeps stored payloads small; ensure_ascii=False preserves text as-is.
        encoded_input = json.dumps(input_items, ensure_ascii=False, separators=(",", ":"))
        encoded_response = json.dumps(response_payload, ensure_ascii=False, separators=(",", ":"))
        statement = self._build_upsert_statement(
            response_id=response_id,
            parent_response_id=parent_response_id,
            account_id=account_id,
            api_key_id=api_key_id,
            model=model,
            input_items_json=encoded_input,
            response_json=encoded_response,
        )
        await self._session.execute(statement)
        await self._session.commit()
        stored = await self.get(response_id, api_key_id=api_key_id)
        if stored is None:
            raise RuntimeError(f"ResponseSnapshot upsert failed for response_id={response_id!r}")
        # Refresh so the returned ORM object reflects database state (e.g. created_at).
        await self._session.refresh(stored)
        return stored

    def _build_upsert_statement(
        self,
        *,
        response_id: str,
        parent_response_id: str | None,
        account_id: str | None,
        api_key_id: str | None,
        model: str,
        input_items_json: str,
        response_json: str,
    ) -> Insert:
        """Build an ON CONFLICT upsert for the session's dialect.

        Raises:
            RuntimeError: for any dialect other than PostgreSQL or SQLite.
        """
        dialect = self._session.get_bind().dialect.name
        if dialect == "postgresql":
            insert_builder = pg_insert
        elif dialect == "sqlite":
            insert_builder = sqlite_insert
        else:
            raise RuntimeError(f"ResponseSnapshot upsert unsupported for dialect={dialect!r}")
        row = {
            "response_id": response_id,
            "parent_response_id": parent_response_id,
            "account_id": account_id,
            "api_key_id": api_key_id,
            "model": model,
            "input_items_json": input_items_json,
            "response_json": response_json,
        }
        # On conflict, overwrite every column except the primary key itself.
        replacements = {name: value for name, value in row.items() if name != "response_id"}
        return (
            insert_builder(ResponseSnapshot)
            .values(**row)
            .on_conflict_do_update(
                index_elements=[ResponseSnapshot.response_id],
                set_=replacements,
            )
        )
Loading
Loading