diff --git a/.env.example b/.env.example index d2781fd6..16846ea4 100644 --- a/.env.example +++ b/.env.example @@ -25,6 +25,8 @@ CODEX_LB_OAUTH_CALLBACK_HOST=127.0.0.1 CODEX_LB_OAUTH_CALLBACK_PORT=1455 CODEX_LB_TOKEN_REFRESH_TIMEOUT_SECONDS=30 CODEX_LB_TOKEN_REFRESH_INTERVAL_DAYS=8 +# Optional direct refresh endpoint override (used by refresh token exchange) +# CODEX_REFRESH_TOKEN_URL_OVERRIDE=https://auth.openai.com/oauth/token # Encryption key file (optional override; recommended for Docker volumes) # CODEX_LB_ENCRYPTION_KEY_FILE=/var/lib/codex-lb/encryption.key @@ -37,6 +39,13 @@ CODEX_LB_USAGE_REFRESH_INTERVAL_SECONDS=60 CODEX_LB_STICKY_SESSION_CLEANUP_ENABLED=true CODEX_LB_STICKY_SESSION_CLEANUP_INTERVAL_SECONDS=300 +# Optional outbound HTTP proxy for upstream/OAuth/model requests +# CODEX_LB_HTTP_PROXY_URL=http://127.0.0.1:8080 + +# Optional additional proxy request guard (in addition to Authorization auth, if enabled) +# CODEX_LB_PROXY_KEY_AUTH_ENABLED=false +# CODEX_LB_PROXY_KEY=replace-with-strong-random-shared-key + # Firewall # Trust X-Forwarded-For for firewall client IP detection (enable only behind trusted reverse proxy) CODEX_LB_FIREWALL_TRUST_PROXY_HEADERS=false diff --git a/.gitignore b/.gitignore index 0ead1932..fe87cae7 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ __pycache__/ .env.* !.env.example .python-version +*.iml # Build artifacts build/ diff --git a/Jenkinsfile b/Jenkinsfile new file mode 100644 index 00000000..94419b7d --- /dev/null +++ b/Jenkinsfile @@ -0,0 +1,5 @@ +@Library(['dada-tuda-jenkins-pipelines@develop', 'maven-lib@1.0.10']) _ +pythonPipeline( + projectName: "codex-proxy", + python: "3.14" +) diff --git a/README.md b/README.md index a7e61e0a..934cd8ae 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,8 @@ Load balancer for ChatGPT accounts. Pool multiple accounts, track usage, manage +- Accounts UI supports batch `auth.json` import and one-click auth archive export. + ## Quick Start ```bash @@ -283,6 +285,8 @@ When enabled, clients must pass a valid API key as a Bearer token: Authorization: Bearer sk-clb-... ``` +Optional extra hardening: enable `CODEX_LB_PROXY_KEY_AUTH_ENABLED=true` with `CODEX_LB_PROXY_KEY` to require `X-Codex-Proxy-Key` on proxy requests in addition to Bearer auth. + **Creating keys**: Dashboard → API Keys → Create. The full key is shown **only once** at creation. Keys support optional expiration, model restrictions, and rate limits (tokens / cost per day / week / month). ## Configuration @@ -290,6 +294,7 @@ Authorization: Bearer sk-clb-... Environment variables with `CODEX_LB_` prefix or `.env.local`. See [`.env.example`](.env.example). Dashboard auth is configured in Settings. SQLite is the default database backend; PostgreSQL is optional via `CODEX_LB_DATABASE_URL` (for example `postgresql+asyncpg://...`). +Container startup also honors `PORT` and auto-loads `/app/.env` when that file is mounted. ## Data diff --git a/app/cli.py b/app/cli.py index e7803731..acd1074c 100644 --- a/app/cli.py +++ b/app/cli.py @@ -1,11 +1,14 @@ from __future__ import annotations import argparse +import logging import os import uvicorn -from app.core.runtime_logging import build_log_config +from app.core.logging import build_log_config, configure_logging + +logger = logging.getLogger(__name__) def _parse_args() -> argparse.Namespace: @@ -24,13 +27,23 @@ def main() -> None: if bool(args.ssl_certfile) ^ bool(args.ssl_keyfile): raise SystemExit("Both --ssl-certfile and --ssl-keyfile must be provided together.") + log_level = configure_logging() + logger.info( + "Starting codex-lb host=%s port=%s ssl=%s log_level=%s access_log=%s", + args.host, + args.port, + bool(args.ssl_certfile and args.ssl_keyfile), + log_level, + False, + ) uvicorn.run( "app.main:app", host=args.host, port=args.port, ssl_certfile=args.ssl_certfile, ssl_keyfile=args.ssl_keyfile, - log_config=build_log_config(), + access_log=False, + log_config=build_log_config(log_level), ) diff --git a/app/core/auth/__init__.py b/app/core/auth/__init__.py index 424799e4..6bfcfe06 100644 --- a/app/core/auth/__init__.py +++ b/app/core/auth/__init__.py @@ -4,7 +4,7 @@ import hashlib import json from dataclasses import dataclass -from datetime import datetime +from datetime import UTC, datetime from uuid import uuid4 from pydantic import BaseModel, ConfigDict, Field @@ -79,6 +79,18 @@ def extract_id_token_claims(id_token: str) -> IdTokenClaims: return IdTokenClaims() +def token_expiry(token: str | None) -> datetime | None: + if not token: + return None + claims = extract_id_token_claims(token) + exp = claims.exp + if isinstance(exp, (int, float)): + return datetime.fromtimestamp(exp, tz=UTC) + if isinstance(exp, str) and exp.isdigit(): + return datetime.fromtimestamp(int(exp), tz=UTC) + return None + + def claims_from_auth(auth: AuthFile) -> AccountClaims: claims = extract_id_token_claims(auth.tokens.id_token) auth_claims = claims.auth or OpenAIAuthClaims() diff --git a/app/core/auth/dependencies.py b/app/core/auth/dependencies.py index 61d80ac5..7076799d 100644 --- a/app/core/auth/dependencies.py +++ b/app/core/auth/dependencies.py @@ -1,11 +1,13 @@ from __future__ import annotations import logging +import secrets from fastapi import Request, Security from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from app.core.clients.usage import UsageFetchError, fetch_usage +from app.core.config.settings import get_settings from app.core.config.settings_cache import get_settings_cache from app.core.exceptions import DashboardAuthError, ProxyAuthError, ProxyUpstreamError from app.db.session import get_background_session @@ -34,8 +36,10 @@ def set_dashboard_error_format(request: Request) -> None: async def validate_proxy_api_key( + request: Request, credentials: HTTPAuthorizationCredentials | None = Security(_bearer), ) -> ApiKeyData | None: + _validate_optional_proxy_key_header(request) authorization = None if credentials is None else f"Bearer {credentials.credentials}" return await validate_proxy_api_key_authorization(authorization) @@ -124,3 +128,22 @@ def _extract_bearer_token(authorization: str | None) -> str | None: if not token: return None return token + + +def _validate_optional_proxy_key_header(request: Request) -> None: + settings = get_settings() + if not settings.proxy_key_auth_enabled: + return + + required_key = settings.proxy_key + if not required_key: + raise ProxyAuthError("X-Codex-Proxy-Key auth is enabled but no proxy key is configured") + + provided = request.headers.get("X-Codex-Proxy-Key") + if not provided: + raise ProxyAuthError("Missing X-Codex-Proxy-Key header") + provided_key = provided.strip() + if not provided_key: + raise ProxyAuthError("Missing X-Codex-Proxy-Key header") + if not secrets.compare_digest(provided_key, required_key): + raise ProxyAuthError("Invalid X-Codex-Proxy-Key header") diff --git a/app/core/auth/refresh.py b/app/core/auth/refresh.py index 2c08b379..5c397dab 100644 --- a/app/core/auth/refresh.py +++ b/app/core/auth/refresh.py @@ -2,6 +2,7 @@ import contextvars import logging +import os from dataclasses import dataclass from datetime import datetime, timedelta @@ -11,13 +12,15 @@ from app.core.auth import OpenAIAuthClaims, extract_id_token_claims from app.core.auth.models import OAuthTokenPayload from app.core.balancer import PERMANENT_FAILURE_CODES -from app.core.clients.http import get_http_client +from app.core.clients.http import get_http_client, get_http_proxy_request_kwargs from app.core.config.settings import get_settings from app.core.types import JsonObject from app.core.utils.request_id import get_request_id from app.core.utils.time import to_utc_naive, utcnow TOKEN_REFRESH_INTERVAL_DAYS = 8 +DEFAULT_REFRESH_TOKEN_URL = "https://auth.openai.com/oauth/token" +REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR = "CODEX_REFRESH_TOKEN_URL_OVERRIDE" logger = logging.getLogger(__name__) _TOKEN_REFRESH_TIMEOUT_OVERRIDE: contextvars.ContextVar[float | None] = contextvars.ContextVar( @@ -57,13 +60,20 @@ def classify_refresh_error(code: str | None) -> bool: return code in PERMANENT_FAILURE_CODES +def refresh_token_endpoint() -> str: + override = os.getenv(REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR) + if override and override.strip(): + return override.strip() + return DEFAULT_REFRESH_TOKEN_URL + + async def refresh_access_token( refresh_token: str, *, session: aiohttp.ClientSession | None = None, ) -> TokenRefreshResult: settings = get_settings() - url = f"{settings.auth_base_url.rstrip('/')}/oauth/token" + url = refresh_token_endpoint() payload = { "grant_type": "refresh_token", "client_id": settings.oauth_client_id, @@ -77,7 +87,8 @@ async def refresh_access_token( request_id = get_request_id() if request_id: headers["x-request-id"] = request_id - async with client_session.post(url, json=payload, headers=headers, timeout=timeout) as resp: + proxy_kwargs = await get_http_proxy_request_kwargs() + async with client_session.post(url, json=payload, headers=headers, timeout=timeout, **proxy_kwargs) as resp: data = await _safe_json(resp) try: payload_data = OAuthTokenPayload.model_validate(data) @@ -132,9 +143,9 @@ async def _safe_json(resp: aiohttp.ClientResponse) -> JsonObject: def _refresh_error_from_payload(payload: OAuthTokenPayload, status_code: int) -> RefreshError: - code = _extract_error_code(payload) or f"http_{status_code}" message = _extract_error_message(payload) or f"Token refresh failed ({status_code})" - return RefreshError(code, message, classify_refresh_error(code)) + code = _normalize_refresh_error_code(_extract_error_code(payload), message, status_code) + return RefreshError(code, message, _is_permanent_refresh_failure(code, message, status_code)) def _effective_token_refresh_timeout(configured_timeout_seconds: float) -> float: @@ -162,3 +173,44 @@ def _extract_error_message(payload: OAuthTokenPayload) -> str | None: if isinstance(error, str): return payload.error_description or error return payload.message + + +def _is_permanent_refresh_failure(code: str | None, message: str, status_code: int) -> bool: + if classify_refresh_error(code): + return True + + normalized_code = (code or "").strip().lower() + normalized_message = message.strip().lower() + if status_code != 401: + return False + + permanent_codes = { + "invalid_grant", + "token_expired", + "session_expired", + } + if normalized_code in permanent_codes: + return True + + permanent_message_fragments = ( + "refresh token has already been used", + "provided authentication token is expired", + "please try signing in again", + "re-login required", + "token is expired", + "token expired", + ) + return any(fragment in normalized_message for fragment in permanent_message_fragments) + + +def _normalize_refresh_error_code(code: str | None, message: str, status_code: int) -> str: + normalized_code = (code or "").strip().lower() + normalized_message = message.strip().lower() + + if status_code == 401: + if "refresh token has already been used" in normalized_message: + return "refresh_token_reused" + if "provided authentication token is expired" in normalized_message or "token expired" in normalized_message: + return "refresh_token_expired" + + return code or f"http_{status_code}" diff --git a/app/core/clients/codex_version.py b/app/core/clients/codex_version.py index 5ac02644..bbb54b90 100644 --- a/app/core/clients/codex_version.py +++ b/app/core/clients/codex_version.py @@ -7,6 +7,7 @@ import aiohttp import anyio +from app.core.clients.http import get_http_proxy_request_kwargs from app.core.config.settings import get_settings logger = logging.getLogger(__name__) @@ -59,9 +60,10 @@ async def invalidate(self) -> None: async def _fetch_latest_version(self) -> str | None: try: timeout = aiohttp.ClientTimeout(total=_FETCH_TIMEOUT_SECONDS) + proxy_kwargs = await get_http_proxy_request_kwargs() async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: headers = {"Accept": "application/vnd.github+json"} - async with session.get(_GITHUB_RELEASES_URL, headers=headers) as resp: + async with session.get(_GITHUB_RELEASES_URL, headers=headers, **proxy_kwargs) as resp: if resp.status != 200: logger.warning("GitHub releases API returned HTTP %d", resp.status) return None diff --git a/app/core/clients/http.py b/app/core/clients/http.py index 7b53f1d8..41ec76d7 100644 --- a/app/core/clients/http.py +++ b/app/core/clients/http.py @@ -5,8 +5,10 @@ import aiohttp from aiohttp_retry import RetryClient +from app.core.config.proxy import normalize_http_proxy_url +from app.core.config.settings import get_settings +from app.core.config.settings_cache import get_settings_cache from app.core.config.settings import get_settings - @dataclass(slots=True) class HttpClient: @@ -59,3 +61,22 @@ def get_http_client() -> HttpClient: if _http_client is None: raise RuntimeError("HTTP client not initialized") return _http_client + + +async def get_http_proxy_url() -> str | None: + env_proxy = normalize_http_proxy_url(get_settings().http_proxy_url) + if env_proxy: + return env_proxy + + try: + settings_row = await get_settings_cache().get() + except Exception: + return None + return normalize_http_proxy_url(getattr(settings_row, "http_proxy_url", None)) + + +async def get_http_proxy_request_kwargs() -> dict[str, str]: + proxy = await get_http_proxy_url() + if not proxy: + return {} + return {"proxy": proxy} diff --git a/app/core/clients/model_fetcher.py b/app/core/clients/model_fetcher.py index 0d41e744..36bb6da0 100644 --- a/app/core/clients/model_fetcher.py +++ b/app/core/clients/model_fetcher.py @@ -6,7 +6,7 @@ import aiohttp from app.core.clients.codex_version import get_codex_version_cache -from app.core.clients.http import get_http_client +from app.core.clients.http import get_http_client, get_http_proxy_request_kwargs from app.core.config.settings import get_settings from app.core.openai.model_registry import ReasoningLevel, UpstreamModel from app.core.types import JsonValue @@ -99,8 +99,9 @@ async def fetch_models_for_plan( timeout = aiohttp.ClientTimeout(total=_FETCH_TIMEOUT_SECONDS) session = get_http_client().session + proxy_kwargs = await get_http_proxy_request_kwargs() - async with session.get(url, headers=headers, timeout=timeout) as resp: + async with session.get(url, headers=headers, timeout=timeout, **proxy_kwargs) as resp: if resp.status >= 400: text = await resp.text() raise ModelFetchError(resp.status, f"HTTP {resp.status}: {text[:200]}") diff --git a/app/core/clients/oauth.py b/app/core/clients/oauth.py index d14b75fb..d1d75a14 100644 --- a/app/core/clients/oauth.py +++ b/app/core/clients/oauth.py @@ -12,7 +12,7 @@ from pydantic import ValidationError from app.core.auth.models import DeviceCodePayload, OAuthTokenPayload -from app.core.clients.http import get_http_client +from app.core.clients.http import get_http_client, get_http_proxy_request_kwargs from app.core.config.settings import get_settings from app.core.types import JsonObject from app.core.utils.request_id import get_request_id @@ -109,11 +109,13 @@ async def exchange_authorization_code( request_id = get_request_id() if request_id: headers["x-request-id"] = request_id + proxy_kwargs = await get_http_proxy_request_kwargs() async with client_session.post( url, data=encoded, headers=headers, timeout=timeout, + **proxy_kwargs, ) as resp: data = await _safe_json(resp) try: @@ -155,7 +157,8 @@ async def request_device_code( request_id = get_request_id() if request_id: headers["x-request-id"] = request_id - async with client_session.post(url, json=payload, headers=headers, timeout=timeout) as resp: + proxy_kwargs = await get_http_proxy_request_kwargs() + async with client_session.post(url, json=payload, headers=headers, timeout=timeout, **proxy_kwargs) as resp: data = await _safe_json(resp) if resp.status >= 400: if resp.status == 404: @@ -223,7 +226,8 @@ async def exchange_device_token( request_id = get_request_id() if request_id: headers["x-request-id"] = request_id - async with client_session.post(url, json=payload, headers=headers, timeout=timeout) as resp: + proxy_kwargs = await get_http_proxy_request_kwargs() + async with client_session.post(url, json=payload, headers=headers, timeout=timeout, **proxy_kwargs) as resp: data = await _safe_json(resp) try: payload_data = OAuthTokenPayload.model_validate(data) diff --git a/app/core/clients/proxy.py b/app/core/clients/proxy.py index 79b4c240..ebd66f52 100644 --- a/app/core/clients/proxy.py +++ b/app/core/clients/proxy.py @@ -20,7 +20,7 @@ from aiohttp.client_ws import DEFAULT_WS_CLIENT_TIMEOUT from multidict import CIMultiDict -from app.core.clients.http import get_http_client +from app.core.clients.http import get_http_client, get_http_proxy_request_kwargs from app.core.config.settings import get_settings from app.core.errors import OpenAIErrorEnvelope, ResponseFailedEvent, openai_error, response_failed_event from app.core.openai.model_registry import get_model_registry @@ -1327,6 +1327,7 @@ async def stream_responses( error_code: str | None = None error_message: str | None = None client_session = session or get_http_client().session + proxy_kwargs = await get_http_proxy_request_kwargs() payload_dict = payload.to_payload() if settings.image_inline_fetch_enabled: payload_dict = await _inline_input_image_urls( @@ -1373,6 +1374,7 @@ async def _stream_via_http( json=payload_dict, headers=current_headers, timeout=current_timeout, + **proxy_kwargs, ) as resp: status_code = resp.status if resp.status >= 400: @@ -1709,6 +1711,7 @@ async def execute(self) -> CompactResponsePayload: pre_request_started_at = time.monotonic() compact_timeout_seconds = _effective_compact_total_timeout(settings.upstream_compact_timeout_seconds) effective_connect_timeout = _effective_compact_connect_timeout(settings.upstream_connect_timeout_seconds) + proxy_kwargs = await get_http_proxy_request_kwargs() payload_dict = self.payload.to_payload() if settings.image_inline_fetch_enabled: payload_dict = await _inline_input_image_urls( @@ -1761,6 +1764,7 @@ async def execute(self) -> CompactResponsePayload: json=payload_dict, headers=upstream_headers, timeout=timeout, + **proxy_kwargs, ) as resp: status_code = resp.status if resp.status >= 400: @@ -1921,6 +1925,7 @@ async def transcribe_audio( form.add_field("prompt", prompt) client_session = session or get_http_client().session + proxy_kwargs = await get_http_proxy_request_kwargs() started_at = time.monotonic() status_code: int | None = None error_code: str | None = None @@ -1947,6 +1952,7 @@ async def transcribe_audio( data=form, headers=upstream_headers, timeout=timeout, + **proxy_kwargs, ) as resp: status_code = resp.status if resp.status >= 400: diff --git a/app/core/clients/usage.py b/app/core/clients/usage.py index 3cd59224..074d42fd 100644 --- a/app/core/clients/usage.py +++ b/app/core/clients/usage.py @@ -2,12 +2,13 @@ import asyncio import logging +import re import aiohttp from aiohttp_retry import ExponentialRetry, RetryClient from pydantic import BaseModel, ConfigDict, ValidationError -from app.core.clients.http import get_http_client +from app.core.clients.http import get_http_client, get_http_proxy_request_kwargs from app.core.config.settings import get_settings from app.core.types import JsonObject from app.core.usage.models import UsagePayload @@ -18,6 +19,7 @@ RETRY_MAX_TIMEOUT = 2.0 logger = logging.getLogger(__name__) +_HTML_TAG_RE = re.compile(r"<[^>]+>") class UsageErrorDetail(BaseModel): @@ -59,6 +61,7 @@ async def fetch_usage( headers = _usage_headers(access_token, account_id) retry_client = client or get_http_client().retry_client retry_options = _retry_options(retries + 1) + proxy_kwargs = await get_http_proxy_request_kwargs() try: async with retry_client.request( @@ -67,6 +70,7 @@ async def fetch_usage( headers=headers, timeout=timeout, retry_options=retry_options, + **proxy_kwargs, ) as resp: data = await _safe_json(resp) if resp.status >= 400: @@ -117,7 +121,7 @@ async def _safe_json(resp: aiohttp.ClientResponse) -> JsonObject: data = await resp.json(content_type=None) except Exception: text = await resp.text() - return {"error": {"message": text.strip()}} + return {"error": {"message": _sanitize_error_text(text)}} return data if isinstance(data, dict) else {"error": {"message": str(data)}} @@ -125,10 +129,32 @@ def _extract_error_message(payload: JsonObject) -> str | None: envelope = UsageErrorEnvelope.model_validate(payload) error = envelope.error if isinstance(error, UsageErrorDetail): - return error.message or error.error_description + message = error.message or error.error_description + return _sanitize_error_text(message) if isinstance(error, str): - return envelope.error_description or error - return envelope.message + return _sanitize_error_text(envelope.error_description or error) + return _sanitize_error_text(envelope.message) + + +def _sanitize_error_text(message: str | None) -> str | None: + if message is None: + return None + normalized = " ".join(message.strip().split()) + if not normalized: + return None + if _looks_like_html(normalized): + return "Upstream returned an HTML error response" + return normalized + + +def _looks_like_html(value: str) -> bool: + lower = value.lower() + return ( + " ExponentialRetry: diff --git a/app/core/config/proxy.py b/app/core/config/proxy.py new file mode 100644 index 00000000..a5b00794 --- /dev/null +++ b/app/core/config/proxy.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from urllib.parse import urlparse + + +def normalize_http_proxy_url(value: str | None) -> str | None: + if value is None: + return None + normalized = value.strip() + if not normalized: + return None + + parsed = urlparse(normalized) + if parsed.scheme not in {"http", "https"} or not parsed.netloc: + raise ValueError("http_proxy_url must be a valid http or https URL") + + return normalized diff --git a/app/core/config/settings.py b/app/core/config/settings.py index bfdced5d..366ce4ea 100644 --- a/app/core/config/settings.py +++ b/app/core/config/settings.py @@ -8,6 +8,8 @@ from pydantic import Field, field_validator from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict +from app.core.config.proxy import normalize_http_proxy_url + BASE_DIR = Path(__file__).resolve().parents[3] DOCKER_DATA_DIR = Path("/var/lib/codex-lb") @@ -78,6 +80,7 @@ class Settings(BaseSettings): openai_cache_affinity_max_age_seconds: int = Field(default=300, gt=0) sticky_session_cleanup_enabled: bool = True sticky_session_cleanup_interval_seconds: int = Field(default=300, gt=0) + http_proxy_url: str | None = None encryption_key_file: Path = DEFAULT_ENCRYPTION_KEY_FILE database_migrations_fail_fast: bool = True log_proxy_request_shape: bool = False @@ -96,6 +99,8 @@ class Settings(BaseSettings): firewall_trusted_proxy_cidrs: Annotated[list[str], NoDecode] = Field( default_factory=lambda: ["127.0.0.1/32", "::1/128"] ) + proxy_key_auth_enabled: bool = False + proxy_key: str | None = None @field_validator("database_url") @classmethod @@ -159,6 +164,25 @@ def _normalize_firewall_trusted_proxy_cidrs(cls, value: object) -> list[str]: raise ValueError(f"Invalid firewall trusted proxy CIDR: {cidr}") from exc return cidrs + @field_validator("http_proxy_url", mode="before") + @classmethod + def _normalize_http_proxy_url(cls, value: object) -> str | None: + if value is None: + return None + if isinstance(value, str): + return normalize_http_proxy_url(value) + raise TypeError("http_proxy_url must be a string") + + @field_validator("proxy_key", mode="before") + @classmethod + def _normalize_proxy_key(cls, value: object) -> str | None: + if value is None: + return None + if isinstance(value, str): + key = value.strip() + return key or None + raise TypeError("proxy_key must be a string") + @field_validator("upstream_compact_timeout_seconds") @classmethod def _validate_upstream_compact_timeout_seconds(cls, value: float | None) -> float | None: diff --git a/app/core/logging.py b/app/core/logging.py new file mode 100644 index 00000000..d1f345db --- /dev/null +++ b/app/core/logging.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import logging +import os +from logging.config import dictConfig +from typing import Any + +DEFAULT_LOG_LEVEL = "INFO" +_ALLOWED_LOG_LEVELS = {"CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"} + + +def resolve_log_level(raw_level: str | None) -> str: + if raw_level is None: + return DEFAULT_LOG_LEVEL + normalized = raw_level.strip().upper() + if normalized in _ALLOWED_LOG_LEVELS: + return normalized + return DEFAULT_LOG_LEVEL + + +def build_log_config(level: str | None = None) -> dict[str, Any]: + resolved_level = resolve_log_level(level or os.getenv("CODEX_LB_LOG_LEVEL")) + formatter = { + "format": "%(asctime)s %(levelname)s %(name)s %(message)s", + "datefmt": "%Y-%m-%d %H:%M:%S", + } + return { + "version": 1, + "disable_existing_loggers": False, + "formatters": {"standard": formatter}, + "handlers": { + "default": { + "class": "logging.StreamHandler", + "formatter": "standard", + "stream": "ext://sys.stdout", + } + }, + "root": {"handlers": ["default"], "level": resolved_level}, + "loggers": { + "uvicorn": {"level": resolved_level}, + "uvicorn.error": {"level": resolved_level, "propagate": True}, + "uvicorn.access": {"handlers": [], "level": "WARNING", "propagate": False}, + }, + } + + +def configure_logging(level: str | None = None) -> str: + resolved_level = resolve_log_level(level or os.getenv("CODEX_LB_LOG_LEVEL")) + dictConfig(build_log_config(resolved_level)) + logging.captureWarnings(True) + return resolved_level diff --git a/app/db/alembic/versions/20260309_000000_add_dashboard_settings_http_proxy_url.py b/app/db/alembic/versions/20260309_000000_add_dashboard_settings_http_proxy_url.py new file mode 100644 index 00000000..132c4f02 --- /dev/null +++ b/app/db/alembic/versions/20260309_000000_add_dashboard_settings_http_proxy_url.py @@ -0,0 +1,45 @@ +"""add dashboard_settings.http_proxy_url + +Revision ID: 20260309_000000_add_dashboard_settings_http_proxy_url +Revises: 20260228_030000_add_api_firewall_allowlist +Create Date: 2026-03-09 +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.engine import Connection + +# revision identifiers, used by Alembic. +revision = "20260309_000000_add_dashboard_settings_http_proxy_url" +down_revision = "20260228_030000_add_api_firewall_allowlist" +branch_labels = None +depends_on = None + + +def _table_exists(connection: Connection, table_name: str) -> bool: + inspector = sa.inspect(connection) + return inspector.has_table(table_name) + + +def _columns(connection: Connection, table_name: str) -> set[str]: + inspector = sa.inspect(connection) + if not inspector.has_table(table_name): + return set() + return {str(column["name"]) for column in inspector.get_columns(table_name) if column.get("name") is not None} + + +def upgrade() -> None: + bind = op.get_bind() + if not _table_exists(bind, "dashboard_settings"): + return + columns = _columns(bind, "dashboard_settings") + if "http_proxy_url" in columns: + return + with op.batch_alter_table("dashboard_settings") as batch_op: + batch_op.add_column(sa.Column("http_proxy_url", sa.Text(), nullable=True)) + + +def downgrade() -> None: + return diff --git a/app/db/alembic/versions/20260313_170000_normalize_legacy_enum_value_casing.py b/app/db/alembic/versions/20260313_170000_normalize_legacy_enum_value_casing.py new file mode 100644 index 00000000..0edf24e3 --- /dev/null +++ b/app/db/alembic/versions/20260313_170000_normalize_legacy_enum_value_casing.py @@ -0,0 +1,95 @@ +"""normalize legacy uppercase enum-like values to lowercase + +Revision ID: 20260313_170000_normalize_legacy_enum_value_casing +Revises: 20260312_120000_add_dashboard_upstream_stream_transport +Create Date: 2026-03-13 17:00:00.000000 +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.engine import Connection + +# revision identifiers, used by Alembic. +revision = "20260313_170000_normalize_legacy_enum_value_casing" +down_revision = "20260312_120000_add_dashboard_upstream_stream_transport" +branch_labels = None +depends_on = None + + +_ACCOUNT_STATUS_RENAMES: tuple[tuple[str, str], ...] = ( + ("ACTIVE", "active"), + ("RATE_LIMITED", "rate_limited"), + ("QUOTA_EXCEEDED", "quota_exceeded"), + ("PAUSED", "paused"), + ("DEACTIVATED", "deactivated"), +) + +_API_KEY_LIMIT_TYPE_RENAMES: tuple[tuple[str, str], ...] = ( + ("TOTAL_TOKENS", "total_tokens"), + ("INPUT_TOKENS", "input_tokens"), + ("OUTPUT_TOKENS", "output_tokens"), + ("COST_USD", "cost_usd"), +) + +_API_KEY_LIMIT_WINDOW_RENAMES: tuple[tuple[str, str], ...] = ( + ("DAILY", "daily"), + ("WEEKLY", "weekly"), + ("MONTHLY", "monthly"), +) + + +def _table_has_column(connection: Connection, table_name: str, column_name: str) -> bool: + inspector = sa.inspect(connection) + if not inspector.has_table(table_name): + return False + columns = inspector.get_columns(table_name) + return any(str(column.get("name")) == column_name for column in columns) + + +def _normalize_column_values( + connection: Connection, + *, + table_name: str, + column_name: str, + renames: tuple[tuple[str, str], ...], +) -> None: + if not _table_has_column(connection, table_name, column_name): + return + + for old_value, new_value in renames: + connection.execute( + sa.text( + f"UPDATE {table_name} " + f"SET {column_name} = :new_value " + f"WHERE CAST({column_name} AS TEXT) = :old_value" + ), + {"old_value": old_value, "new_value": new_value}, + ) + + +def upgrade() -> None: + bind = op.get_bind() + _normalize_column_values( + bind, + table_name="accounts", + column_name="status", + renames=_ACCOUNT_STATUS_RENAMES, + ) + _normalize_column_values( + bind, + table_name="api_key_limits", + column_name="limit_type", + renames=_API_KEY_LIMIT_TYPE_RENAMES, + ) + _normalize_column_values( + bind, + table_name="api_key_limits", + column_name="limit_window", + renames=_API_KEY_LIMIT_WINDOW_RENAMES, + ) + + +def downgrade() -> None: + return diff --git a/app/db/alembic/versions/20260313_190000_merge_dashboard_and_enum_heads.py b/app/db/alembic/versions/20260313_190000_merge_dashboard_and_enum_heads.py new file mode 100644 index 00000000..885b0280 --- /dev/null +++ b/app/db/alembic/versions/20260313_190000_merge_dashboard_and_enum_heads.py @@ -0,0 +1,27 @@ +"""merge dashboard settings and enum normalization heads + +Revision ID: 20260313_190000_merge_dashboard_and_enum_heads +Revises: + 20260309_000000_add_dashboard_settings_http_proxy_url, + 20260313_170000_normalize_legacy_enum_value_casing +Create Date: 2026-03-13 19:00:00.000000 +""" + +from __future__ import annotations + +# revision identifiers, used by Alembic. +revision = "20260313_190000_merge_dashboard_and_enum_heads" +down_revision = ( + "20260309_000000_add_dashboard_settings_http_proxy_url", + "20260313_170000_normalize_legacy_enum_value_casing", +) +branch_labels = None +depends_on = None + + +def upgrade() -> None: + return + + +def downgrade() -> None: + return diff --git a/app/db/migrate.py b/app/db/migrate.py index 6cb5c795..0062ac67 100644 --- a/app/db/migrate.py +++ b/app/db/migrate.py @@ -541,7 +541,7 @@ async def run_startup_migrations(database_url: str) -> MigrationRunResult: return await to_thread.run_sync( lambda: run_upgrade( database_url, - "head", + "heads", bootstrap_legacy=True, auto_remap_legacy_revisions=auto_remap, ), @@ -570,7 +570,7 @@ def _parse_args() -> argparse.Namespace: subparsers = parser.add_subparsers(dest="command", required=True) upgrade_parser = subparsers.add_parser("upgrade", help="Upgrade schema to a target revision.") - upgrade_parser.add_argument("revision", nargs="?", default="head") + upgrade_parser.add_argument("revision", nargs="?", default="heads") upgrade_parser.add_argument( "--no-bootstrap-legacy", action="store_true", diff --git a/app/db/models.py b/app/db/models.py index 2d434165..6a922f4e 100644 --- a/app/db/models.py +++ b/app/db/models.py @@ -185,6 +185,7 @@ class DashboardSettings(Base): server_default=false(), nullable=False, ) + http_proxy_url: Mapped[str | None] = mapped_column(Text, nullable=True) totp_required_on_login: Mapped[bool] = mapped_column( Boolean, default=False, diff --git a/app/main.py b/app/main.py index a5a2907a..a212b9ce 100644 --- a/app/main.py +++ b/app/main.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from contextlib import asynccontextmanager from pathlib import Path @@ -33,9 +34,12 @@ from app.modules.usage import api as usage_api from app.modules.usage.additional_quota_keys import reload_additional_quota_registry +logger = logging.getLogger(__name__) + @asynccontextmanager async def lifespan(_: FastAPI): + logger.info("Application startup: initializing settings caches, database, HTTP client, and schedulers") await get_settings_cache().invalidate() await get_rate_limit_headers_cache().invalidate() reload_additional_quota_registry() @@ -46,11 +50,13 @@ async def lifespan(_: FastAPI): sticky_session_cleanup_scheduler = build_sticky_session_cleanup_scheduler() await usage_scheduler.start() await model_scheduler.start() + logger.info("Application startup complete") await sticky_session_cleanup_scheduler.start() try: yield finally: + logger.info("Application shutdown: stopping schedulers and closing resources") await sticky_session_cleanup_scheduler.stop() await model_scheduler.stop() await usage_scheduler.stop() @@ -58,6 +64,7 @@ async def lifespan(_: FastAPI): await close_http_client() finally: await close_db() + logger.info("Application shutdown complete") def create_app() -> FastAPI: @@ -92,10 +99,12 @@ def create_app() -> FastAPI: app.include_router(health_api.router) static_dir = Path(__file__).parent / "static" + static_dir.mkdir(parents=True, exist_ok=True) index_html = static_dir / "index.html" static_root = static_dir.resolve() frontend_build_hint = "Frontend assets are missing. Run `cd frontend && bun run build`." excluded_prefixes = ("api/", "v1/", "backend-api/", "health") + logger.debug("Created FastAPI application and registered API routers") def _is_static_asset_path(path: str) -> bool: if path.startswith("assets/"): diff --git a/app/modules/accounts/api.py b/app/modules/accounts/api.py index c75fdbcf..99de380f 100644 --- a/app/modules/accounts/api.py +++ b/app/modules/accounts/api.py @@ -1,6 +1,6 @@ from __future__ import annotations -from fastapi import APIRouter, Depends, File, UploadFile +from fastapi import APIRouter, Depends, File, Response, UploadFile from app.core.auth.dependencies import set_dashboard_error_format, validate_dashboard_session from app.core.exceptions import DashboardBadRequestError, DashboardConflictError, DashboardNotFoundError @@ -8,13 +8,14 @@ from app.modules.accounts.repository import AccountIdentityConflictError from app.modules.accounts.schemas import ( AccountDeleteResponse, + AccountImportBatchResponse, AccountImportResponse, AccountPauseResponse, AccountReactivateResponse, AccountsResponse, AccountTrendsResponse, ) -from app.modules.accounts.service import InvalidAuthJsonError +from app.modules.accounts.service import AuthRefreshFailedError, ImportFilePayload, InvalidAuthJsonError router = APIRouter( prefix="/api/accounts", @@ -49,11 +50,37 @@ async def import_account( ) -> AccountImportResponse: raw = await auth_json.read() try: - return await context.service.import_account(raw) + return await context.service.import_account(raw, filename=auth_json.filename) except InvalidAuthJsonError as exc: raise DashboardBadRequestError("Invalid auth.json payload", code="invalid_auth_json") from exc except AccountIdentityConflictError as exc: raise DashboardConflictError(str(exc), code="duplicate_identity_conflict") from exc + except AuthRefreshFailedError as exc: + raise DashboardBadRequestError(exc.message, code=exc.code) from exc + + +@router.post("/import/batch", response_model=AccountImportBatchResponse) +async def import_accounts( + auth_json: list[UploadFile] = File(...), + context: AccountsContext = Depends(get_accounts_context), +) -> AccountImportBatchResponse: + files = [ + ImportFilePayload(filename=file.filename, raw=await file.read()) + for file in auth_json + ] + return await context.service.import_accounts(files) + + +@router.get("/export") +async def export_accounts( + context: AccountsContext = Depends(get_accounts_context), +) -> Response: + filename, archive = await context.service.export_accounts_archive() + return Response( + content=archive, + media_type="application/zip", + headers={"Content-Disposition": f'attachment; filename="{filename}"'}, + ) @router.post("/{account_id}/reactivate", response_model=AccountReactivateResponse) diff --git a/app/modules/accounts/auth_manager.py b/app/modules/accounts/auth_manager.py index 39577c6b..3abb89c6 100644 --- a/app/modules/accounts/auth_manager.py +++ b/app/modules/accounts/auth_manager.py @@ -55,7 +55,7 @@ async def refresh_account(self, account: Account) -> Account: result = await refresh_access_token(refresh_token) except RefreshError as exc: if exc.is_permanent: - reason = PERMANENT_FAILURE_CODES.get(exc.code, exc.message) + reason = _permanent_refresh_failure_reason(exc) await self._repo.update_status(account.id, AccountStatus.DEACTIVATED, reason) account.status = AccountStatus.DEACTIVATED account.deactivation_reason = reason @@ -121,3 +121,15 @@ def _chatgpt_account_id_from_id_token(id_token: str) -> str | None: claims = extract_id_token_claims(id_token) auth_claims = claims.auth or OpenAIAuthClaims() return auth_claims.chatgpt_account_id or claims.chatgpt_account_id + + +def _permanent_refresh_failure_reason(exc: RefreshError) -> str: + if exc.code in PERMANENT_FAILURE_CODES: + return PERMANENT_FAILURE_CODES[exc.code] + + normalized_message = exc.message.strip().lower() + if "refresh token has already been used" in normalized_message: + return PERMANENT_FAILURE_CODES["refresh_token_reused"] + if "provided authentication token is expired" in normalized_message or "token expired" in normalized_message: + return PERMANENT_FAILURE_CODES["refresh_token_expired"] + return exc.message diff --git a/app/modules/accounts/mappers.py b/app/modules/accounts/mappers.py index a73f1464..1d5fd67a 100644 --- a/app/modules/accounts/mappers.py +++ b/app/modules/accounts/mappers.py @@ -3,7 +3,7 @@ from datetime import datetime, timezone from app.core import usage as usage_core -from app.core.auth import DEFAULT_PLAN, extract_id_token_claims +from app.core.auth import DEFAULT_PLAN, extract_id_token_claims, token_expiry from app.core.crypto import TokenEncryptor from app.core.plan_types import coerce_account_plan_type from app.core.usage.types import UsageTrendBucket, UsageWindowRow @@ -175,15 +175,7 @@ def _decrypt_token(encryptor: TokenEncryptor, encrypted: bytes | None) -> str | def _token_expiry(token: str | None) -> datetime | None: - if not token: - return None - claims = extract_id_token_claims(token) - exp = claims.exp - if isinstance(exp, (int, float)): - return datetime.fromtimestamp(exp, tz=timezone.utc) - if isinstance(exp, str) and exp.isdigit(): - return datetime.fromtimestamp(int(exp), tz=timezone.utc) - return None + return token_expiry(token) def _normalize_used_percent(entry: UsageHistory | None) -> float | None: diff --git a/app/modules/accounts/schemas.py b/app/modules/accounts/schemas.py index 0aff2561..a0909a14 100644 --- a/app/modules/accounts/schemas.py +++ b/app/modules/accounts/schemas.py @@ -83,10 +83,23 @@ class AccountsResponse(DashboardModel): class AccountImportResponse(DashboardModel): + filename: str | None = None account_id: str email: str plan_type: str status: str + refreshed_on_import: bool = False + + +class AccountImportFailure(DashboardModel): + filename: str | None = None + code: str + message: str + + +class AccountImportBatchResponse(DashboardModel): + imported: list[AccountImportResponse] = Field(default_factory=list) + failed: list[AccountImportFailure] = Field(default_factory=list) class AccountPauseResponse(DashboardModel): diff --git a/app/modules/accounts/service.py b/app/modules/accounts/service.py index 4eeba85a..a6feb096 100644 --- a/app/modules/accounts/service.py +++ b/app/modules/accounts/service.py @@ -1,25 +1,34 @@ from __future__ import annotations +import io import json +import zipfile +from dataclasses import dataclass from datetime import timedelta from typing import cast from pydantic import ValidationError from app.core.auth import ( + AuthFile, + AuthTokens, DEFAULT_EMAIL, DEFAULT_PLAN, claims_from_auth, generate_unique_account_id, parse_auth_json, + token_expiry, ) from app.core.crypto import TokenEncryptor from app.core.plan_types import coerce_account_plan_type from app.core.utils.time import naive_utc_to_epoch, to_utc_naive, utcnow from app.db.models import Account, AccountStatus +from app.modules.accounts.auth_manager import AuthManager from app.modules.accounts.mappers import build_account_summaries, build_account_usage_trends -from app.modules.accounts.repository import AccountsRepository +from app.modules.accounts.repository import AccountIdentityConflictError, AccountsRepository from app.modules.accounts.schemas import ( + AccountImportBatchResponse, + AccountImportFailure, AccountAdditionalQuota, AccountAdditionalWindow, AccountImportResponse, @@ -30,6 +39,7 @@ from app.modules.usage.additional_quota_keys import get_additional_display_label_for_quota_key from app.modules.usage.repository import AdditionalUsageRepository, UsageRepository from app.modules.usage.updater import AdditionalUsageRepositoryPort, UsageUpdater +from app.core.auth.refresh import RefreshError, TokenRefreshResult, refresh_access_token _SPARKLINE_DAYS = 7 _DETAIL_BUCKET_SECONDS = 3600 # 1h → 168 points @@ -39,6 +49,19 @@ class InvalidAuthJsonError(Exception): pass +class AuthRefreshFailedError(Exception): + def __init__(self, code: str, message: str) -> None: + super().__init__(message) + self.code = code + self.message = message + + +@dataclass(frozen=True) +class ImportFilePayload: + filename: str | None + raw: bytes + + class AccountsService: def __init__( self, @@ -51,6 +74,7 @@ def __init__( self._additional_usage_repo = additional_usage_repo self._usage_updater = UsageUpdater(usage_repo, repo, additional_usage_repo) if usage_repo else None self._encryptor = TokenEncryptor() + self._auth_manager = AuthManager(repo) async def list_accounts(self) -> list[AccountSummary]: accounts = await self._repo.list_accounts() @@ -139,11 +163,12 @@ async def get_account_trends(self, account_id: str) -> AccountTrendsResponse | N secondary=trend.secondary if trend else [], ) - async def import_account(self, raw: bytes) -> AccountImportResponse: + async def import_account(self, raw: bytes, *, filename: str | None = None) -> AccountImportResponse: try: auth = parse_auth_json(raw) except (json.JSONDecodeError, ValidationError, UnicodeDecodeError, TypeError) as exc: raise InvalidAuthJsonError("Invalid auth.json payload") from exc + auth, refreshed_on_import = await self._refresh_import_auth_if_needed(auth) claims = claims_from_auth(auth) email = claims.email or DEFAULT_EMAIL @@ -170,12 +195,56 @@ async def import_account(self, raw: bytes) -> AccountImportResponse: latest_usage = await self._usage_repo.latest_by_account(window="primary") await self._usage_updater.refresh_accounts([saved], latest_usage) return AccountImportResponse( + filename=filename, account_id=saved.id, email=saved.email, plan_type=saved.plan_type, status=saved.status, + refreshed_on_import=refreshed_on_import, ) + async def import_accounts(self, files: list[ImportFilePayload]) -> AccountImportBatchResponse: + imported: list[AccountImportResponse] = [] + failed: list[AccountImportFailure] = [] + + for file in files: + try: + imported.append(await self.import_account(file.raw, filename=file.filename)) + except InvalidAuthJsonError as exc: + failed.append( + AccountImportFailure(filename=file.filename, code="invalid_auth_json", message=str(exc)), + ) + except AccountIdentityConflictError as exc: + failed.append( + AccountImportFailure( + filename=file.filename, + code="duplicate_identity_conflict", + message=str(exc), + ), + ) + except AuthRefreshFailedError as exc: + failed.append( + AccountImportFailure(filename=file.filename, code=exc.code, message=exc.message), + ) + + return AccountImportBatchResponse(imported=imported, failed=failed) + + async def export_accounts_archive(self) -> tuple[str, bytes]: + accounts = await self._repo.list_accounts() + archive_buffer = io.BytesIO() + + with zipfile.ZipFile(archive_buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive: + for account in accounts: + export_account = await self._refresh_account_for_export(account) + payload = self._serialize_account_auth(export_account) + archive.writestr( + f"{_safe_archive_segment(export_account.email)}__{export_account.id}/auth.json", + json.dumps(payload, indent=2), + ) + + filename = f"auth-export-{utcnow().strftime('%Y%m%d-%H%M%S')}.zip" + return filename, archive_buffer.getvalue() + async def reactivate_account(self, account_id: str) -> bool: return await self._repo.update_status(account_id, AccountStatus.ACTIVE, None) @@ -184,3 +253,66 @@ async def pause_account(self, account_id: str) -> bool: async def delete_account(self, account_id: str) -> bool: return await self._repo.delete(account_id) + + async def _refresh_import_auth_if_needed(self, auth: AuthFile) -> tuple[AuthFile, bool]: + expires_at = token_expiry(auth.tokens.access_token) + if not expires_at or to_utc_naive(expires_at) > utcnow() or not auth.tokens.refresh_token: + return auth, False + + try: + result = await refresh_access_token(auth.tokens.refresh_token) + except RefreshError as exc: + raise AuthRefreshFailedError("refresh_failed", exc.message) from exc + return _auth_file_from_refresh_result(result), True + + async def _refresh_account_for_export(self, account: Account) -> Account: + access_token = self._decrypt_token(account.access_token_encrypted) + expires_at = token_expiry(access_token) + should_force = expires_at is not None and to_utc_naive(expires_at) <= utcnow() + + try: + if should_force: + return await self._auth_manager.ensure_fresh(account, force=True) + return await self._auth_manager.ensure_fresh(account) + except RefreshError: + return account + + def _serialize_account_auth(self, account: Account) -> dict[str, object]: + access_token = self._decrypt_token(account.access_token_encrypted) + refresh_token = self._decrypt_token(account.refresh_token_encrypted) + id_token = self._decrypt_token(account.id_token_encrypted) + payload = AuthFile( + tokens=AuthTokens( + idToken=id_token or "", + accessToken=access_token or "", + refreshToken=refresh_token or "", + accountId=account.chatgpt_account_id, + ), + lastRefreshAt=account.last_refresh, + ) + return payload.model_dump(mode="json", by_alias=True, exclude_none=True) + + def _decrypt_token(self, encrypted: bytes | None) -> str | None: + if not encrypted: + return None + try: + return self._encryptor.decrypt(encrypted) + except Exception: + return None + + +def _auth_file_from_refresh_result(result: TokenRefreshResult) -> AuthFile: + return AuthFile( + tokens=AuthTokens( + idToken=result.id_token, + accessToken=result.access_token, + refreshToken=result.refresh_token, + accountId=result.account_id, + ), + lastRefreshAt=utcnow(), + ) + + +def _safe_archive_segment(value: str) -> str: + sanitized = "".join(ch if ch.isalnum() or ch in {"-", "_", "."} else "-" for ch in value.strip()) + return sanitized or "account" diff --git a/app/modules/proxy/load_balancer.py b/app/modules/proxy/load_balancer.py index 32a4c569..a9bb5af1 100644 --- a/app/modules/proxy/load_balancer.py +++ b/app/modules/proxy/load_balancer.py @@ -326,6 +326,8 @@ def _prune_runtime(self, accounts: Iterable[Account]) -> None: stale_ids = [account_id for account_id in self._runtime if account_id not in account_ids] for account_id in stale_ids: self._runtime.pop(account_id, None) + if stale_ids: + logger.debug("Pruned stale runtime state account_ids=%s", stale_ids) async def _select_with_stickiness( self, @@ -390,6 +392,7 @@ async def _select_with_stickiness( return chosen async def mark_rate_limit(self, account: Account, error: UpstreamError) -> None: + logger.info("Marking account as rate-limited account_id=%s reset_at=%s", account.id, error.get("resets_at")) async with self._runtime_lock: state = self._state_for(account) handle_rate_limit(state, error) @@ -397,6 +400,7 @@ async def mark_rate_limit(self, account: Account, error: UpstreamError) -> None: await self._sync_state(repos.accounts, account, state) async def mark_quota_exceeded(self, account: Account, error: UpstreamError) -> None: + logger.info("Marking account as quota-exceeded account_id=%s reset_at=%s", account.id, error.get("resets_at")) async with self._runtime_lock: state = self._state_for(account) handle_quota_exceeded(state, error) @@ -404,6 +408,7 @@ async def mark_quota_exceeded(self, account: Account, error: UpstreamError) -> N await self._sync_state(repos.accounts, account, state) async def mark_permanent_failure(self, account: Account, error_code: str) -> None: + logger.info("Marking account as permanently failed account_id=%s error_code=%s", account.id, error_code) async with self._runtime_lock: state = self._state_for(account) handle_permanent_failure(state, error_code) @@ -415,6 +420,7 @@ async def record_error(self, account: Account) -> None: state = self._state_for(account) state.error_count += 1 state.last_error_at = time.time() + logger.debug("Recorded upstream error account_id=%s error_count=%s", account.id, state.error_count) async with self._repo_factory() as repos: await self._sync_state(repos.accounts, account, state) diff --git a/app/modules/proxy/service.py b/app/modules/proxy/service.py index 1fad9f92..74ee535b 100644 --- a/app/modules/proxy/service.py +++ b/app/modules/proxy/service.py @@ -179,6 +179,7 @@ async def compact_responses( ) -> CompactResponsePayload: _maybe_log_proxy_request_payload("compact", payload, headers) _maybe_log_proxy_request_shape("compact", payload, headers) + logger.info("Handling compact response request request_id=%s model=%s", ensure_request_id(), payload.model) filtered = filter_inbound_headers(headers) request_id = get_request_id() or ensure_request_id(None) start = time.monotonic() @@ -273,6 +274,12 @@ async def _call_compact(target: Account) -> CompactResponsePayload: response = await _call_compact(account) actual_service_tier = _service_tier_from_response(response) await self._load_balancer.record_success(account) + logger.info( + "Compact response completed request_id=%s account_id=%s model=%s", + get_request_id(), + account.id, + payload.model, + ) await self._settle_compact_api_key_usage( api_key=api_key, api_key_reservation=api_key_reservation, @@ -392,6 +399,12 @@ async def transcribe( headers: Mapping[str, str], api_key: ApiKeyData | None = None, ) -> dict[str, JsonValue]: + logger.info( + "Handling transcription request request_id=%s filename=%s content_type=%s", + ensure_request_id(), + filename, + content_type, + ) filtered = filter_inbound_headers(headers) request_id = get_request_id() or ensure_request_id(None) start = time.monotonic() @@ -1832,8 +1845,22 @@ async def _stream_with_retry( settled = False any_attempt_logged = False settlement = _StreamSettlement() + logger.info( + "Starting stream request request_id=%s model=%s sticky=%s max_attempts=%s", + request_id, + payload.model, + bool(affinity.key), + max_attempts, + ) try: for attempt in range(max_attempts): + logger.debug( + "Selecting account for stream request request_id=%s attempt=%s/%s model=%s", + request_id, + attempt + 1, + max_attempts, + payload.model, + ) remaining_budget = _remaining_budget_seconds(deadline) if remaining_budget <= 0: logger.warning( @@ -1898,6 +1925,14 @@ async def _stream_with_retry( return account = selection.account if not account: + logger.warning( + "No account available for stream request request_id=%s attempt=%s/%s model=%s error=%s", + request_id, + attempt + 1, + max_attempts, + payload.model, + selection.error_message, + ) no_accounts_msg = selection.error_message or "No active accounts available" error_code = selection.error_code or "no_accounts" event = response_failed_event( @@ -1978,6 +2013,14 @@ async def _stream_with_retry( return any_attempt_logged = True settlement = _StreamSettlement() + logger.info( + "Proxying stream request request_id=%s attempt=%s/%s account_id=%s model=%s", + request_id, + attempt + 1, + max_attempts, + account.id, + payload.model, + ) effective_attempt_timeout = _remaining_budget_seconds(deadline) if effective_attempt_timeout <= 0: logger.warning( @@ -2032,8 +2075,24 @@ async def _stream_with_retry( settlement, request_id, ) + logger.info( + "Stream request completed request_id=%s account_id=%s status=%s input_tokens=%s output_tokens=%s", + request_id, + account.id, + settlement.status, + settlement.input_tokens, + settlement.output_tokens, + ) return except _RetryableStreamError as exc: + logger.warning( + "Retryable stream error request_id=%s attempt=%s/%s account_id=%s code=%s", + request_id, + attempt + 1, + max_attempts, + account.id, + exc.code, + ) await self._handle_stream_error(account, exc.error, exc.code) continue except _TerminalStreamError as exc: @@ -2158,6 +2217,12 @@ async def _stream_with_retry( settlement, request_id, ) + logger.info( + "Stream request completed after refresh request_id=%s account_id=%s status=%s", + request_id, + account.id, + settlement.status, + ) return error = _parse_openai_error(exc.payload) error_code = _normalize_error_code(error.code if error else None, error.type if error else None) @@ -2185,6 +2250,15 @@ async def _stream_with_retry( except RefreshError as exc: if exc.is_permanent: await self._load_balancer.mark_permanent_failure(account, exc.code) + logger.warning( + "Account refresh failed during stream request request_id=%s attempt=%s/%s account_id=%s permanent=%s code=%s", + request_id, + attempt + 1, + max_attempts, + account.id, + exc.is_permanent, + exc.code, + ) continue except Exception: logger.warning( @@ -2200,6 +2274,7 @@ async def _stream_with_retry( ) yield format_sse_event(event) return + logger.warning("Stream request exhausted retries request_id=%s model=%s", request_id, payload.model) retries_exhausted_msg = "No available accounts after retries" event = response_failed_event( "no_accounts", @@ -2510,6 +2585,7 @@ async def _write_stream_preflight_error( ) async def _refresh_usage(self, repos: ProxyRepositories, accounts: list[Account]) -> None: + logger.debug("Refreshing usage snapshot for rate limit status account_count=%s", len(accounts)) latest_usage = await repos.usage.latest_by_account(window="primary") updater = UsageUpdater(repos.usage, repos.accounts, repos.additional_usage) await updater.refresh_accounts(accounts, latest_usage) @@ -2678,6 +2754,7 @@ async def _ensure_fresh( force: bool = False, timeout_seconds: float | None = None, ) -> Account: + logger.debug("Ensuring account freshness account_id=%s force=%s", account.id, force) async with self._repo_factory() as repos: auth_manager = AuthManager(repos.accounts) token = push_token_refresh_timeout_override(timeout_seconds) @@ -2753,6 +2830,12 @@ async def _handle_stream_error( error: UpstreamError, code: str, ) -> None: + logger.info( + "Handling upstream proxy error account_id=%s code=%s reset_at=%s", + account.id, + code, + error.get("resets_at"), + ) if code in {"rate_limit_exceeded", "usage_limit_reached"}: await self._load_balancer.mark_rate_limit(account, error) return diff --git a/app/modules/settings/api.py b/app/modules/settings/api.py index 9265e6f4..f78ba957 100644 --- a/app/modules/settings/api.py +++ b/app/modules/settings/api.py @@ -81,6 +81,7 @@ async def get_settings( routing_strategy=settings.routing_strategy, openai_cache_affinity_max_age_seconds=settings.openai_cache_affinity_max_age_seconds, import_without_overwrite=settings.import_without_overwrite, + http_proxy_url=settings.http_proxy_url, totp_required_on_login=settings.totp_required_on_login, totp_configured=settings.totp_configured, api_key_auth_enabled=settings.api_key_auth_enabled, @@ -115,6 +116,7 @@ async def update_settings( if payload.import_without_overwrite is not None else current.import_without_overwrite ), + http_proxy_url=payload.http_proxy_url if "http_proxy_url" in payload.model_fields_set else current.http_proxy_url, totp_required_on_login=( payload.totp_required_on_login if payload.totp_required_on_login is not None @@ -138,6 +140,7 @@ async def update_settings( routing_strategy=updated.routing_strategy, openai_cache_affinity_max_age_seconds=updated.openai_cache_affinity_max_age_seconds, import_without_overwrite=updated.import_without_overwrite, + http_proxy_url=updated.http_proxy_url, totp_required_on_login=updated.totp_required_on_login, totp_configured=updated.totp_configured, api_key_auth_enabled=updated.api_key_auth_enabled, diff --git a/app/modules/settings/repository.py b/app/modules/settings/repository.py index d9842ca1..8d4f7940 100644 --- a/app/modules/settings/repository.py +++ b/app/modules/settings/repository.py @@ -7,6 +7,7 @@ from app.db.models import DashboardSettings _SETTINGS_ID = 1 +_UNSET = object() class SettingsRepository: @@ -26,6 +27,7 @@ async def get_or_create(self) -> DashboardSettings: routing_strategy="usage_weighted", openai_cache_affinity_max_age_seconds=get_settings().openai_cache_affinity_max_age_seconds, import_without_overwrite=False, + http_proxy_url=None, totp_required_on_login=False, password_hash=None, api_key_auth_enabled=False, @@ -53,6 +55,7 @@ async def update( routing_strategy: str | None = None, openai_cache_affinity_max_age_seconds: int | None = None, import_without_overwrite: bool | None = None, + http_proxy_url: str | None | object = _UNSET, totp_required_on_login: bool | None = None, api_key_auth_enabled: bool | None = None, ) -> DashboardSettings: @@ -69,6 +72,8 @@ async def update( settings.openai_cache_affinity_max_age_seconds = openai_cache_affinity_max_age_seconds if import_without_overwrite is not None: settings.import_without_overwrite = import_without_overwrite + if http_proxy_url is not _UNSET: + settings.http_proxy_url = http_proxy_url if isinstance(http_proxy_url, str) or http_proxy_url is None else None if totp_required_on_login is not None: settings.totp_required_on_login = totp_required_on_login if api_key_auth_enabled is not None: diff --git a/app/modules/settings/schemas.py b/app/modules/settings/schemas.py index 12ec8b40..cc5a7be2 100644 --- a/app/modules/settings/schemas.py +++ b/app/modules/settings/schemas.py @@ -1,7 +1,9 @@ from __future__ import annotations from pydantic import Field +from pydantic import field_validator +from app.core.config.proxy import normalize_http_proxy_url from app.modules.shared.schemas import DashboardModel @@ -12,6 +14,7 @@ class DashboardSettingsResponse(DashboardModel): routing_strategy: str = Field(pattern=r"^(usage_weighted|round_robin)$") openai_cache_affinity_max_age_seconds: int = Field(gt=0) import_without_overwrite: bool + http_proxy_url: str | None = None totp_required_on_login: bool totp_configured: bool api_key_auth_enabled: bool @@ -27,9 +30,17 @@ class DashboardSettingsUpdateRequest(DashboardModel): routing_strategy: str | None = Field(default=None, pattern=r"^(usage_weighted|round_robin)$") openai_cache_affinity_max_age_seconds: int | None = Field(default=None, gt=0) import_without_overwrite: bool | None = None + http_proxy_url: str | None = None totp_required_on_login: bool | None = None api_key_auth_enabled: bool | None = None - + @field_validator("http_proxy_url", mode="before") + @classmethod + def _normalize_http_proxy_url(cls, value: object) -> str | None: + if value is None: + return None + if isinstance(value, str): + return normalize_http_proxy_url(value) + raise TypeError("http_proxy_url must be a string") class RuntimeConnectAddressResponse(DashboardModel): - connect_address: str + connect_address: str \ No newline at end of file diff --git a/app/modules/settings/service.py b/app/modules/settings/service.py index 4ac2486d..c8f532d8 100644 --- a/app/modules/settings/service.py +++ b/app/modules/settings/service.py @@ -13,6 +13,7 @@ class DashboardSettingsData: routing_strategy: str openai_cache_affinity_max_age_seconds: int import_without_overwrite: bool + http_proxy_url: str | None totp_required_on_login: bool totp_configured: bool api_key_auth_enabled: bool @@ -26,6 +27,7 @@ class DashboardSettingsUpdateData: routing_strategy: str openai_cache_affinity_max_age_seconds: int import_without_overwrite: bool + http_proxy_url: str | None totp_required_on_login: bool api_key_auth_enabled: bool @@ -43,6 +45,7 @@ async def get_settings(self) -> DashboardSettingsData: routing_strategy=row.routing_strategy, openai_cache_affinity_max_age_seconds=row.openai_cache_affinity_max_age_seconds, import_without_overwrite=row.import_without_overwrite, + http_proxy_url=row.http_proxy_url, totp_required_on_login=row.totp_required_on_login, totp_configured=row.totp_secret_encrypted is not None, api_key_auth_enabled=row.api_key_auth_enabled, @@ -59,6 +62,7 @@ async def update_settings(self, payload: DashboardSettingsUpdateData) -> Dashboa routing_strategy=payload.routing_strategy, openai_cache_affinity_max_age_seconds=payload.openai_cache_affinity_max_age_seconds, import_without_overwrite=payload.import_without_overwrite, + http_proxy_url=payload.http_proxy_url, totp_required_on_login=payload.totp_required_on_login, api_key_auth_enabled=payload.api_key_auth_enabled, ) @@ -69,6 +73,7 @@ async def update_settings(self, payload: DashboardSettingsUpdateData) -> Dashboa routing_strategy=row.routing_strategy, openai_cache_affinity_max_age_seconds=row.openai_cache_affinity_max_age_seconds, import_without_overwrite=row.import_without_overwrite, + http_proxy_url=row.http_proxy_url, totp_required_on_login=row.totp_required_on_login, totp_configured=row.totp_secret_encrypted is not None, api_key_auth_enabled=row.api_key_auth_enabled, diff --git a/app/modules/usage/updater.py b/app/modules/usage/updater.py index 61bcd4f1..6827a8f1 100644 --- a/app/modules/usage/updater.py +++ b/app/modules/usage/updater.py @@ -666,7 +666,8 @@ def _reset_at(reset_at: int | None, reset_after_seconds: int | None, now_epoch: # The usage endpoint can return 403 for accounts that are still otherwise usable # for proxy traffic, so treat it as a refresh failure instead of a permanent # account-level deactivation signal. -_DEACTIVATING_USAGE_STATUS_CODES = {402, 404} + +_DEACTIVATING_USAGE_STATUS_CODES = {} def _should_deactivate_for_usage_error(status_code: int) -> bool: diff --git a/docker-compose.yml b/docker-compose.yml index ae5c6a36..99529e80 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -33,6 +33,7 @@ services: restart: unless-stopped frontend: + restart: unless-stopped build: context: ./frontend dockerfile_inline: | diff --git a/frontend/src/__integration__/accounts-flow.test.tsx b/frontend/src/__integration__/accounts-flow.test.tsx index e152361b..62cc03cd 100644 --- a/frontend/src/__integration__/accounts-flow.test.tsx +++ b/frontend/src/__integration__/accounts-flow.test.tsx @@ -1,6 +1,6 @@ import { screen, waitFor } from "@testing-library/react"; import userEvent from "@testing-library/user-event"; -import { describe, expect, it } from "vitest"; +import { describe, expect, it, vi } from "vitest"; import App from "@/App"; import { renderWithProviders } from "@/test/utils"; @@ -32,4 +32,33 @@ describe("accounts flow integration", () => { }); } }); + + it("supports batch import and auth archive download", async () => { + const user = userEvent.setup({ delay: null }); + const createObjectURL = vi.spyOn(URL, "createObjectURL").mockReturnValue("blob:test"); + const revokeObjectURL = vi.spyOn(URL, "revokeObjectURL").mockImplementation(() => {}); + const clickSpy = vi.spyOn(HTMLAnchorElement.prototype, "click").mockImplementation(() => {}); + + window.history.pushState({}, "", "/accounts"); + renderWithProviders(); + + expect(await screen.findByRole("heading", { name: "Accounts" })).toBeInTheDocument(); + expect((await screen.findAllByText("primary@example.com")).length).toBeGreaterThan(0); + + await user.click(screen.getByRole("button", { name: "Import" })); + await user.upload(screen.getByLabelText("Files"), [ + new File(["{}"], "alpha.json", { type: "application/json" }), + new File(["{}"], "beta.json", { type: "application/json" }), + ]); + expect(screen.getByText("2 files selected")).toBeInTheDocument(); + await user.click(screen.getByRole("button", { name: "Close" })); + + await user.click(screen.getByRole("button", { name: "All Auth ZIP" })); + + await waitFor(() => { + expect(createObjectURL).toHaveBeenCalledTimes(1); + expect(clickSpy).toHaveBeenCalledTimes(1); + expect(revokeObjectURL).toHaveBeenCalledWith("blob:test"); + }); + }); }); diff --git a/frontend/src/features/accounts/api.ts b/frontend/src/features/accounts/api.ts index 0ab45d35..4f6bfeef 100644 --- a/frontend/src/features/accounts/api.ts +++ b/frontend/src/features/accounts/api.ts @@ -1,8 +1,8 @@ -import { del, get, post } from "@/lib/api-client"; +import { del, get, getBlob, post } from "@/lib/api-client"; import { AccountActionResponseSchema, - AccountImportResponseSchema, + AccountImportBatchResponseSchema, AccountsResponseSchema, AccountTrendsResponseSchema, ManualOauthCallbackRequestSchema, @@ -22,14 +22,20 @@ export function listAccounts() { return get(ACCOUNTS_BASE_PATH, AccountsResponseSchema); } -export function importAccount(file: File) { +export function importAccounts(files: File[]) { const formData = new FormData(); - formData.append("auth_json", file); - return post(`${ACCOUNTS_BASE_PATH}/import`, AccountImportResponseSchema, { + for (const file of files) { + formData.append("auth_json", file); + } + return post(`${ACCOUNTS_BASE_PATH}/import/batch`, AccountImportBatchResponseSchema, { body: formData, }); } +export function downloadAccountsAuthArchive() { + return getBlob(`${ACCOUNTS_BASE_PATH}/export`); +} + export function pauseAccount(accountId: string) { return post( `${ACCOUNTS_BASE_PATH}/${encodeURIComponent(accountId)}/pause`, diff --git a/frontend/src/features/accounts/components/account-list.test.tsx b/frontend/src/features/accounts/components/account-list.test.tsx index f4a60c49..ad56c840 100644 --- a/frontend/src/features/accounts/components/account-list.test.tsx +++ b/frontend/src/features/accounts/components/account-list.test.tsx @@ -8,6 +8,7 @@ describe("AccountList", () => { it("renders items and filters by search", async () => { const user = userEvent.setup(); const onSelect = vi.fn(); + const onDownloadAuthExport = vi.fn(); render( { selectedAccountId="acc-1" onSelect={onSelect} onOpenImport={() => {}} + onDownloadAuthExport={onDownloadAuthExport} onOpenOauth={() => {}} />, ); @@ -45,6 +47,9 @@ describe("AccountList", () => { await user.click(screen.getByText("secondary@example.com")); expect(onSelect).toHaveBeenCalledWith("acc-2"); + + await user.click(screen.getByRole("button", { name: "All Auth ZIP" })); + expect(onDownloadAuthExport).toHaveBeenCalledTimes(1); }); it("shows empty state when no items match filter", async () => { @@ -65,6 +70,7 @@ describe("AccountList", () => { selectedAccountId={null} onSelect={() => {}} onOpenImport={() => {}} + onDownloadAuthExport={() => {}} onOpenOauth={() => {}} />, ); @@ -105,6 +111,7 @@ describe("AccountList", () => { selectedAccountId={null} onSelect={() => {}} onOpenImport={() => {}} + onDownloadAuthExport={() => {}} onOpenOauth={() => {}} />, ); diff --git a/frontend/src/features/accounts/components/account-list.tsx b/frontend/src/features/accounts/components/account-list.tsx index d44d6d0d..937493c7 100644 --- a/frontend/src/features/accounts/components/account-list.tsx +++ b/frontend/src/features/accounts/components/account-list.tsx @@ -1,4 +1,4 @@ -import { ChevronDown, ChevronUp, Plus, Search, Upload } from "lucide-react"; +import { Download,ChevronDown, ChevronUp, Plus, Search, Upload} from "lucide-react"; import { useMemo, useState } from "react"; import { Button } from "@/components/ui/button"; @@ -23,6 +23,8 @@ export type AccountListProps = { selectedAccountId: string | null; onSelect: (accountId: string) => void; onOpenImport: () => void; + onDownloadAuthExport: () => void; + downloadBusy?: boolean; onOpenOauth: () => void; }; @@ -31,6 +33,8 @@ export function AccountList({ selectedAccountId, onSelect, onOpenImport, + onDownloadAuthExport, + downloadBusy = false, onOpenOauth, }: AccountListProps) { const [search, setSearch] = useState(""); @@ -82,11 +86,22 @@ export function AccountList({ -
+
+
@@ -107,9 +117,20 @@ export function AccountsPage() { open={importDialog.open} busy={importMutation.isPending} error={getErrorMessageOrNull(importMutation.error)} - onOpenChange={importDialog.onOpenChange} - onImport={async (file) => { - await importMutation.mutateAsync(file); + result={lastImportResult} + onOpenChange={(open) => { + if (!open) { + setLastImportResult(null); + } + importDialog.onOpenChange(open); + }} + onImport={async (files) => { + const result = await importMutation.mutateAsync(files); + setLastImportResult(result); + if (result.imported.length > 0) { + await accountsQuery.refetch(); + } + return result; }} /> diff --git a/frontend/src/features/accounts/components/import-dialog.test.tsx b/frontend/src/features/accounts/components/import-dialog.test.tsx new file mode 100644 index 00000000..6ef421a1 --- /dev/null +++ b/frontend/src/features/accounts/components/import-dialog.test.tsx @@ -0,0 +1,75 @@ +import { render, screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { describe, expect, it, vi } from "vitest"; + +import { ImportDialog } from "@/features/accounts/components/import-dialog"; + +describe("ImportDialog", () => { + it("submits multiple files and closes when all imports succeed", async () => { + const user = userEvent.setup(); + const onImport = vi.fn().mockResolvedValue({ + imported: [ + { + filename: "one.json", + accountId: "acc-1", + email: "one@example.com", + planType: "plus", + status: "active", + refreshedOnImport: false, + }, + ], + failed: [], + }); + const onOpenChange = vi.fn(); + + render( + , + ); + + const files = [ + new File(["{}"], "one.json", { type: "application/json" }), + new File(["{}"], "two.json", { type: "application/json" }), + ]; + + await user.upload(screen.getByLabelText("Files"), files); + expect(screen.getByText("2 files selected")).toBeInTheDocument(); + + await user.click(screen.getByRole("button", { name: "Import" })); + + expect(onImport).toHaveBeenCalledWith(files); + expect(onOpenChange).toHaveBeenCalledWith(false); + }); + + it("renders import failures returned by the batch endpoint", () => { + render( + {}} + onImport={vi.fn()} + />, + ); + + expect(screen.getByText("Imported 0 files, 1 failed.")).toBeInTheDocument(); + expect(screen.getByText("broken.json:")).toBeInTheDocument(); + expect(screen.getByText(/Invalid auth\.json payload/)).toBeInTheDocument(); + }); +}); diff --git a/frontend/src/features/accounts/components/import-dialog.tsx b/frontend/src/features/accounts/components/import-dialog.tsx index 8df3bd05..e4a83082 100644 --- a/frontend/src/features/accounts/components/import-dialog.tsx +++ b/frontend/src/features/accounts/components/import-dialog.tsx @@ -1,6 +1,7 @@ import { useState } from "react"; import type { FormEvent } from "react"; +import { AlertMessage } from "@/components/alert-message"; import { Button } from "@/components/ui/button"; import { Dialog, @@ -12,51 +13,69 @@ import { } from "@/components/ui/dialog"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; +import type { AccountImportBatchResponse } from "@/features/accounts/schemas"; export type ImportDialogProps = { open: boolean; busy: boolean; error: string | null; + result: AccountImportBatchResponse | null; onOpenChange: (open: boolean) => void; - onImport: (file: File) => Promise; + onImport: (files: File[]) => Promise; }; export function ImportDialog({ open, busy, error, + result, onOpenChange, onImport, }: ImportDialogProps) { - const [file, setFile] = useState(null); + const [files, setFiles] = useState([]); const handleSubmit = async (event: FormEvent) => { event.preventDefault(); - if (!file) { + if (files.length === 0) { return; } - await onImport(file); - onOpenChange(false); - setFile(null); + const importResult = await onImport(files); + setFiles([]); + if (importResult.failed.length === 0) { + onOpenChange(false); + } }; return ( - Import auth.json - Upload an exported account auth.json file. + Import auth.json files + Upload one or more exported account auth.json files in a single batch.
- + setFile(event.target.files?.[0] ?? null)} + multiple + onChange={(event) => setFiles(Array.from(event.target.files ?? []))} /> +

+ {files.length === 0 + ? "Select one or more auth.json files." + : `${files.length} file${files.length === 1 ? "" : "s"} selected`} +

+ {files.length > 0 ? ( +
    + {files.map((file) => ( +
  • {file.name}
  • + ))} +
+ ) : null}
{error ? ( @@ -65,8 +84,27 @@ export function ImportDialog({

) : null} + {result ? ( +
+ 0 ? "error" : "success"}> + {result.failed.length > 0 + ? `Imported ${result.imported.length} file${result.imported.length === 1 ? "" : "s"}, ${result.failed.length} failed.` + : `Imported ${result.imported.length} file${result.imported.length === 1 ? "" : "s"} successfully.`} + + {result.failed.length > 0 ? ( +
    + {result.failed.map((failure) => ( +
  • + {failure.filename ?? "Unknown file"}: {failure.message} +
  • + ))} +
+ ) : null} +
+ ) : null} + - diff --git a/frontend/src/features/accounts/hooks/use-accounts.test.ts b/frontend/src/features/accounts/hooks/use-accounts.test.ts index 45d822af..6a308eba 100644 --- a/frontend/src/features/accounts/hooks/use-accounts.test.ts +++ b/frontend/src/features/accounts/hooks/use-accounts.test.ts @@ -4,6 +4,11 @@ import { createElement, type PropsWithChildren } from "react"; import { describe, expect, it, vi } from "vitest"; import { useAccounts } from "@/features/accounts/hooks/use-accounts"; +import { downloadBlob } from "@/lib/download"; + +vi.mock("@/lib/download", () => ({ + downloadBlob: vi.fn(), +})); function createTestQueryClient(): QueryClient { return new QueryClient({ @@ -38,9 +43,18 @@ describe("useAccounts", () => { await result.current.resumeMutation.mutateAsync(firstAccountId as string); const imported = await result.current.importMutation.mutateAsync( - new File(["{}"], "auth.json", { type: "application/json" }), + [ + new File(["{}"], "auth-1.json", { type: "application/json" }), + new File(["{}"], "auth-2.json", { type: "application/json" }), + ], ); - await result.current.deleteMutation.mutateAsync(imported.accountId); + expect(imported.imported).toHaveLength(2); + + await result.current.exportAuthArchiveMutation.mutateAsync(); + expect(downloadBlob).toHaveBeenCalledTimes(1); + expect(downloadBlob).toHaveBeenCalledWith(expect.anything(), "auth-export-test.zip"); + + await result.current.deleteMutation.mutateAsync(imported.imported[0]?.accountId ?? ""); await waitFor(() => { expect(invalidateSpy).toHaveBeenCalledWith({ queryKey: ["accounts", "list"] }); diff --git a/frontend/src/features/accounts/hooks/use-accounts.ts b/frontend/src/features/accounts/hooks/use-accounts.ts index e4983d80..bb5c9c25 100644 --- a/frontend/src/features/accounts/hooks/use-accounts.ts +++ b/frontend/src/features/accounts/hooks/use-accounts.ts @@ -3,12 +3,14 @@ import { toast } from "sonner"; import { deleteAccount, + downloadAccountsAuthArchive, getAccountTrends, - importAccount, + importAccounts, listAccounts, pauseAccount, reactivateAccount, } from "@/features/accounts/api"; +import { downloadBlob } from "@/lib/download"; function invalidateAccountRelatedQueries(queryClient: ReturnType) { void queryClient.invalidateQueries({ queryKey: ["accounts", "list"] }); @@ -24,16 +26,43 @@ export function useAccountMutations() { const queryClient = useQueryClient(); const importMutation = useMutation({ - mutationFn: importAccount, - onSuccess: () => { - toast.success("Account imported"); - invalidateAccountRelatedQueries(queryClient); + mutationFn: importAccounts, + onSuccess: (result) => { + const importedCount = result.imported.length; + const failedCount = result.failed.length; + + if (importedCount > 0) { + invalidateAccountRelatedQueries(queryClient); + } + + if (failedCount === 0) { + toast.success(importedCount === 1 ? "Imported 1 account" : `Imported ${importedCount} accounts`); + return; + } + + if (importedCount === 0) { + toast.error(failedCount === 1 ? "Import failed for 1 file" : `Import failed for ${failedCount} files`); + return; + } + + toast.success(`Imported ${importedCount} account${importedCount === 1 ? "" : "s"}, ${failedCount} failed`); }, onError: (error: Error) => { toast.error(error.message || "Import failed"); }, }); + const exportAuthArchiveMutation = useMutation({ + mutationFn: downloadAccountsAuthArchive, + onSuccess: ({ blob, filename }) => { + downloadBlob(blob, filename ?? "auth-export.zip"); + toast.success("Downloaded auth archive"); + }, + onError: (error: Error) => { + toast.error(error.message || "Download failed"); + }, + }); + const pauseMutation = useMutation({ mutationFn: pauseAccount, onSuccess: () => { @@ -67,7 +96,7 @@ export function useAccountMutations() { }, }); - return { importMutation, pauseMutation, resumeMutation, deleteMutation }; + return { importMutation, exportAuthArchiveMutation, pauseMutation, resumeMutation, deleteMutation }; } export function useAccountTrends(accountId: string | null) { diff --git a/frontend/src/features/accounts/schemas.test.ts b/frontend/src/features/accounts/schemas.test.ts index 8b88ce1e..09fa21cb 100644 --- a/frontend/src/features/accounts/schemas.test.ts +++ b/frontend/src/features/accounts/schemas.test.ts @@ -1,6 +1,7 @@ import { describe, expect, it } from "vitest"; import { + AccountImportBatchResponseSchema, AccountSummarySchema, ImportStateSchema, OAuthStateSchema, @@ -105,3 +106,30 @@ describe("ImportStateSchema", () => { ).toBe(true); }); }); + +describe("AccountImportBatchResponseSchema", () => { + it("parses imported and failed batch results", () => { + const parsed = AccountImportBatchResponseSchema.parse({ + imported: [ + { + filename: "auth-1.json", + accountId: "acc-1", + email: "user@example.com", + planType: "plus", + status: "active", + refreshedOnImport: true, + }, + ], + failed: [ + { + filename: "broken.json", + code: "invalid_auth_json", + message: "Invalid auth.json payload", + }, + ], + }); + + expect(parsed.imported[0]?.refreshedOnImport).toBe(true); + expect(parsed.failed[0]?.code).toBe("invalid_auth_json"); + }); +}); diff --git a/frontend/src/features/accounts/schemas.ts b/frontend/src/features/accounts/schemas.ts index 030292ef..a578ff4a 100644 --- a/frontend/src/features/accounts/schemas.ts +++ b/frontend/src/features/accounts/schemas.ts @@ -75,10 +75,23 @@ export const AccountsResponseSchema = z.object({ }); export const AccountImportResponseSchema = z.object({ + filename: z.string().nullable().optional(), accountId: z.string(), email: z.string(), planType: z.string(), status: z.string(), + refreshedOnImport: z.boolean().optional().default(false), +}); + +export const AccountImportFailureSchema = z.object({ + filename: z.string().nullable().optional(), + code: z.string(), + message: z.string(), +}); + +export const AccountImportBatchResponseSchema = z.object({ + imported: z.array(AccountImportResponseSchema), + failed: z.array(AccountImportFailureSchema), }); export const AccountActionResponseSchema = z.object({ @@ -151,6 +164,9 @@ export type AccountSummary = z.infer; export type AccountAdditionalWindow = z.infer; export type AccountAdditionalQuota = z.infer; export type AccountTrendsResponse = z.infer; +export type AccountImportResponse = z.infer; +export type AccountImportFailure = z.infer; +export type AccountImportBatchResponse = z.infer; export type OauthStartResponse = z.infer; export type OauthStatusResponse = z.infer; export type ManualOauthCallbackResponse = z.infer; diff --git a/frontend/src/features/settings/components/routing-settings.test.tsx b/frontend/src/features/settings/components/routing-settings.test.tsx index fda3af5d..7f73bc3b 100644 --- a/frontend/src/features/settings/components/routing-settings.test.tsx +++ b/frontend/src/features/settings/components/routing-settings.test.tsx @@ -12,6 +12,7 @@ const BASE_SETTINGS: DashboardSettings = { routingStrategy: "usage_weighted", openaiCacheAffinityMaxAgeSeconds: 300, importWithoutOverwrite: false, + httpProxyUrl: null, totpRequiredOnLogin: false, totpConfigured: false, apiKeyAuthEnabled: true, diff --git a/frontend/src/features/settings/components/routing-settings.tsx b/frontend/src/features/settings/components/routing-settings.tsx index d1630cf1..8169735e 100644 --- a/frontend/src/features/settings/components/routing-settings.tsx +++ b/frontend/src/features/settings/components/routing-settings.tsx @@ -1,4 +1,4 @@ -import { useState } from "react"; +import { useEffect, useState } from "react"; import { Route } from "lucide-react"; import { Button } from "@/components/ui/button"; @@ -21,9 +21,16 @@ export type RoutingSettingsProps = { }; export function RoutingSettings({ settings, busy, onSave }: RoutingSettingsProps) { - const [cacheAffinityTtl, setCacheAffinityTtl] = useState( - String(settings.openaiCacheAffinityMaxAgeSeconds), - ); + const [httpProxyUrl, setHttpProxyUrl] = useState(settings.httpProxyUrl ?? ""); + const [cacheAffinityTtl, setCacheAffinityTtl] = useState(String(settings.openaiCacheAffinityMaxAgeSeconds)); + + useEffect(() => { + setHttpProxyUrl(settings.httpProxyUrl ?? ""); + }, [settings.httpProxyUrl]); + + useEffect(() => { + setCacheAffinityTtl(String(settings.openaiCacheAffinityMaxAgeSeconds)); + }, [settings.openaiCacheAffinityMaxAgeSeconds]); const save = (patch: Partial) => void onSave(buildSettingsUpdateRequest(settings, patch)); @@ -33,6 +40,14 @@ export function RoutingSettings({ settings, busy, onSave }: RoutingSettingsProps const cacheAffinityTtlChanged = cacheAffinityTtlValid && parsedCacheAffinityTtl !== settings.openaiCacheAffinityMaxAgeSeconds; + const trimmedHttpProxyUrl = httpProxyUrl.trim(); + const savedHttpProxyUrl = settings.httpProxyUrl ?? ""; + const proxyDirty = trimmedHttpProxyUrl !== savedHttpProxyUrl; + + const handleProxySave = () => { + save({ httpProxyUrl: trimmedHttpProxyUrl.length > 0 ? trimmedHttpProxyUrl : null }); + }; + return (
@@ -152,6 +167,45 @@ export function RoutingSettings({ settings, busy, onSave }: RoutingSettingsProps
+ +
+
+

HTTP proxy

+

+ Route outgoing backend requests through an HTTP or HTTPS proxy. +

+
+ { + event.preventDefault(); + handleProxySave(); + }} + > + setHttpProxyUrl(event.target.value)} + /> + + + +
diff --git a/frontend/src/features/settings/components/totp-settings.test.tsx b/frontend/src/features/settings/components/totp-settings.test.tsx index b77f6c9a..70ed74ad 100644 --- a/frontend/src/features/settings/components/totp-settings.test.tsx +++ b/frontend/src/features/settings/components/totp-settings.test.tsx @@ -24,6 +24,7 @@ const baseSettings = { routingStrategy: "usage_weighted" as const, openaiCacheAffinityMaxAgeSeconds: 300, importWithoutOverwrite: false, + httpProxyUrl: null, totpRequiredOnLogin: false, totpConfigured: false, apiKeyAuthEnabled: true, @@ -109,6 +110,7 @@ describe("TotpSettings", () => { routingStrategy: "usage_weighted", openaiCacheAffinityMaxAgeSeconds: 300, importWithoutOverwrite: false, + httpProxyUrl: null, totpRequiredOnLogin: true, apiKeyAuthEnabled: true, }); diff --git a/frontend/src/features/settings/schemas.test.ts b/frontend/src/features/settings/schemas.test.ts index 1ed32ea8..f630239e 100644 --- a/frontend/src/features/settings/schemas.test.ts +++ b/frontend/src/features/settings/schemas.test.ts @@ -14,6 +14,7 @@ describe("DashboardSettingsSchema", () => { routingStrategy: "round_robin", openaiCacheAffinityMaxAgeSeconds: 300, importWithoutOverwrite: true, + httpProxyUrl: "http://proxy.example:8080", totpRequiredOnLogin: true, totpConfigured: false, apiKeyAuthEnabled: true, @@ -24,6 +25,7 @@ describe("DashboardSettingsSchema", () => { expect(parsed.routingStrategy).toBe("round_robin"); expect(parsed.openaiCacheAffinityMaxAgeSeconds).toBe(300); expect(parsed.importWithoutOverwrite).toBe(true); + expect(parsed.httpProxyUrl).toBe("http://proxy.example:8080"); expect(parsed.apiKeyAuthEnabled).toBe(true); }); }); @@ -37,6 +39,7 @@ describe("SettingsUpdateRequestSchema", () => { routingStrategy: "usage_weighted", openaiCacheAffinityMaxAgeSeconds: 120, importWithoutOverwrite: true, + httpProxyUrl: "https://proxy.example:8443", totpRequiredOnLogin: true, apiKeyAuthEnabled: false, }); @@ -45,6 +48,7 @@ describe("SettingsUpdateRequestSchema", () => { expect(parsed.upstreamStreamTransport).toBe("websocket"); expect(parsed.importWithoutOverwrite).toBe(true); expect(parsed.routingStrategy).toBe("usage_weighted"); + expect(parsed.httpProxyUrl).toBe("https://proxy.example:8443"); expect(parsed.totpRequiredOnLogin).toBe(true); expect(parsed.apiKeyAuthEnabled).toBe(false); }); @@ -57,6 +61,7 @@ describe("SettingsUpdateRequestSchema", () => { expect(parsed.upstreamStreamTransport).toBeUndefined(); expect(parsed.importWithoutOverwrite).toBeUndefined(); + expect(parsed.httpProxyUrl).toBeUndefined(); expect(parsed.totpRequiredOnLogin).toBeUndefined(); expect(parsed.apiKeyAuthEnabled).toBeUndefined(); expect(parsed.openaiCacheAffinityMaxAgeSeconds).toBeUndefined(); diff --git a/frontend/src/features/settings/schemas.ts b/frontend/src/features/settings/schemas.ts index 53114d04..8857ba7d 100644 --- a/frontend/src/features/settings/schemas.ts +++ b/frontend/src/features/settings/schemas.ts @@ -10,6 +10,7 @@ export const DashboardSettingsSchema = z.object({ routingStrategy: RoutingStrategySchema, openaiCacheAffinityMaxAgeSeconds: z.number().int().positive(), importWithoutOverwrite: z.boolean(), + httpProxyUrl: z.string().url().nullable(), totpRequiredOnLogin: z.boolean(), totpConfigured: z.boolean(), apiKeyAuthEnabled: z.boolean(), @@ -22,6 +23,7 @@ export const SettingsUpdateRequestSchema = z.object({ routingStrategy: RoutingStrategySchema.optional(), openaiCacheAffinityMaxAgeSeconds: z.number().int().positive().optional(), importWithoutOverwrite: z.boolean().optional(), + httpProxyUrl: z.string().url().nullable().optional(), totpRequiredOnLogin: z.boolean().optional(), apiKeyAuthEnabled: z.boolean().optional(), }); diff --git a/frontend/src/lib/api-client.ts b/frontend/src/lib/api-client.ts index 153ea00b..68fa9bca 100644 --- a/frontend/src/lib/api-client.ts +++ b/frontend/src/lib/api-client.ts @@ -10,6 +10,12 @@ type RequestOptions = { credentials?: RequestCredentials; }; +export type BlobResponse = { + blob: Blob; + filename: string | null; + contentType: string | null; +}; + const JSON_CONTENT_TYPE = "application/json"; const EMPTY_RESPONSE_STATUS = new Set([204, 205]); @@ -113,6 +119,29 @@ function parseApiErrorPayload(payload: unknown): { }; } +function parseFilenameFromContentDisposition(header: string | null): string | null { + if (!header) { + return null; + } + + const utf8Match = header.match(/filename\*=UTF-8''([^;]+)/i); + if (utf8Match?.[1]) { + try { + return decodeURIComponent(utf8Match[1]); + } catch { + return utf8Match[1]; + } + } + + const quotedMatch = header.match(/filename=\"([^\"]+)\"/i); + if (quotedMatch?.[1]) { + return quotedMatch[1]; + } + + const bareMatch = header.match(/filename=([^;]+)/i); + return bareMatch?.[1]?.trim() ?? null; +} + async function request( method: HttpMethod, url: string, @@ -195,6 +224,57 @@ async function request( return parsed.data; } +async function requestBlob(method: HttpMethod, url: string, options?: RequestOptions): Promise { + const requestBody = buildRequestBody(options?.body); + const headers = new Headers(options?.headers); + if (requestBody.contentType && !headers.has("Content-Type")) { + headers.set("Content-Type", requestBody.contentType); + } + if (!headers.has("Accept")) { + headers.set("Accept", "application/octet-stream, application/zip;q=0.9, */*;q=0.8"); + } + + let response: Response; + try { + response = await fetch(url, { + method, + body: requestBody.body, + headers, + signal: options?.signal, + credentials: options?.credentials ?? "same-origin", + }); + } catch (error) { + throw new ApiError({ + status: 0, + code: "network_error", + message: error instanceof Error ? error.message : "Network request failed", + details: error, + }); + } + + if (response.status === 401) { + unauthorizedHandler?.(); + } + + if (!response.ok) { + const payload = await readJsonPayload(response); + const parsedError = parseApiErrorPayload(payload); + throw new ApiError({ + status: response.status, + code: parsedError.code, + message: parsedError.message, + details: parsedError.details, + payload, + }); + } + + return { + blob: await response.blob(), + filename: parseFilenameFromContentDisposition(response.headers.get("Content-Disposition")), + contentType: response.headers.get("Content-Type"), + }; +} + export function get( url: string, schema: ZodType, @@ -203,6 +283,10 @@ export function get( return request("GET", url, schema, options); } +export function getBlob(url: string, options?: RequestOptions): Promise { + return requestBlob("GET", url, options); +} + export function post( url: string, schema: ZodType, diff --git a/frontend/src/lib/download.test.ts b/frontend/src/lib/download.test.ts new file mode 100644 index 00000000..3ae257a3 --- /dev/null +++ b/frontend/src/lib/download.test.ts @@ -0,0 +1,22 @@ +import { describe, expect, it, vi } from "vitest"; + +import { downloadBlob } from "@/lib/download"; + +describe("downloadBlob", () => { + it("creates an object URL, clicks a temporary link, and revokes the URL", () => { + const blob = new Blob(["zip"]); + const createObjectURL = vi.spyOn(URL, "createObjectURL").mockReturnValue("blob:test"); + const revokeObjectURL = vi.spyOn(URL, "revokeObjectURL").mockImplementation(() => {}); + const clickSpy = vi.spyOn(HTMLAnchorElement.prototype, "click").mockImplementation(() => {}); + + vi.useFakeTimers(); + downloadBlob(blob, "auth-export.zip"); + vi.runAllTimers(); + + expect(createObjectURL).toHaveBeenCalledWith(blob); + expect(clickSpy).toHaveBeenCalledTimes(1); + expect(revokeObjectURL).toHaveBeenCalledWith("blob:test"); + + vi.useRealTimers(); + }); +}); diff --git a/frontend/src/lib/download.ts b/frontend/src/lib/download.ts new file mode 100644 index 00000000..2b919c8e --- /dev/null +++ b/frontend/src/lib/download.ts @@ -0,0 +1,11 @@ +export function downloadBlob(blob: Blob, filename: string): void { + const url = URL.createObjectURL(blob); + const anchor = document.createElement("a"); + anchor.href = url; + anchor.download = filename; + anchor.style.display = "none"; + document.body.appendChild(anchor); + anchor.click(); + anchor.remove(); + setTimeout(() => URL.revokeObjectURL(url), 0); +} diff --git a/frontend/src/test/mocks/factories.ts b/frontend/src/test/mocks/factories.ts index baad3fd2..e6852dd8 100644 --- a/frontend/src/test/mocks/factories.ts +++ b/frontend/src/test/mocks/factories.ts @@ -1,4 +1,7 @@ import { + AccountImportBatchResponseSchema, + AccountImportFailureSchema, + AccountImportResponseSchema, AccountSummarySchema, AccountTrendsResponseSchema, OauthStartResponseSchema, @@ -6,6 +9,9 @@ import { OauthCompleteResponseSchema, } from "@/features/accounts/schemas"; import type { + AccountImportBatchResponse, + AccountImportFailure, + AccountImportResponse, AccountSummary, AccountTrendsResponse, OauthStartResponse, @@ -40,6 +46,9 @@ export type DashboardAuthSession = AuthSession; export type OauthCompleteResponse = z.infer; export type { + AccountImportBatchResponse, + AccountImportFailure, + AccountImportResponse, AccountSummary, AccountTrendsResponse, DashboardOverview, @@ -98,6 +107,41 @@ export function createDefaultAccounts(): AccountSummary[] { ]; } +export function createAccountImportResponse( + overrides: Partial = {}, +): AccountImportResponse { + return AccountImportResponseSchema.parse({ + filename: "auth.json", + accountId: "acc_imported_1", + email: "imported-1@example.com", + planType: "plus", + status: "active", + refreshedOnImport: false, + ...overrides, + }); +} + +export function createAccountImportFailure( + overrides: Partial = {}, +): AccountImportFailure { + return AccountImportFailureSchema.parse({ + filename: "broken.json", + code: "invalid_auth_json", + message: "Invalid auth.json payload", + ...overrides, + }); +} + +export function createAccountImportBatchResponse( + overrides: Partial = {}, +): AccountImportBatchResponse { + return AccountImportBatchResponseSchema.parse({ + imported: [createAccountImportResponse()], + failed: [], + ...overrides, + }); +} + function createTrendPoints(baseValue: number, count = 28): Array<{ t: string; v: number }> { return Array.from({ length: count }, (_, i) => ({ t: new Date(BASE_TIME.getTime() - (count - i) * 6 * 3600_000).toISOString(), @@ -282,6 +326,7 @@ export function createDashboardSettings(overrides: Partial = routingStrategy: "usage_weighted", openaiCacheAffinityMaxAgeSeconds: 300, importWithoutOverwrite: false, + httpProxyUrl: null, totpRequiredOnLogin: false, totpConfigured: true, apiKeyAuthEnabled: true, diff --git a/frontend/src/test/mocks/handler-coverage.test.ts b/frontend/src/test/mocks/handler-coverage.test.ts index 0d182268..4eed4dc9 100644 --- a/frontend/src/test/mocks/handler-coverage.test.ts +++ b/frontend/src/test/mocks/handler-coverage.test.ts @@ -28,7 +28,8 @@ const EXPECTED_ENDPOINTS = [ "GET /api/request-logs/options", // accounts "GET /api/accounts", - "POST /api/accounts/import", + "POST /api/accounts/import/batch", + "GET /api/accounts/export", "POST /api/accounts/:accountId/pause", "POST /api/accounts/:accountId/reactivate", "GET /api/accounts/:accountId/trends", diff --git a/frontend/src/test/mocks/handlers.ts b/frontend/src/test/mocks/handlers.ts index 0c293a1c..654b3c9f 100644 --- a/frontend/src/test/mocks/handlers.ts +++ b/frontend/src/test/mocks/handlers.ts @@ -3,6 +3,9 @@ import { z } from "zod"; import { LIMIT_TYPES, LIMIT_WINDOWS } from "@/features/api-keys/schemas"; import { + createAccountImportBatchResponse, + createAccountImportFailure, + createAccountImportResponse, createAccountSummary, createAccountTrends, createApiKey, @@ -18,8 +21,10 @@ import { createOauthStatusResponse, createRequestLogFilterOptions, createRequestLogsResponse, + type AccountImportFailure, type AccountSummary, type ApiKey, + type AccountImportResponse, type DashboardAuthSession, type DashboardSettings, type RequestLogEntry, @@ -64,6 +69,7 @@ const SettingsPayloadSchema = z.object({ routingStrategy: z.enum(["usage_weighted", "round_robin"]).optional(), openaiCacheAffinityMaxAgeSeconds: z.number().int().positive().optional(), importWithoutOverwrite: z.boolean().optional(), + httpProxyUrl: z.string().url().nullable().optional(), totpRequiredOnLogin: z.boolean().optional(), totpConfigured: z.boolean().optional(), apiKeyAuthEnabled: z.boolean().optional(), @@ -233,6 +239,10 @@ function findApiKey(keyId: string): ApiKey | undefined { return state.apiKeys.find((item) => item.id === keyId); } +function isUploadedFileLike(value: FormDataEntryValue): value is File { + return typeof value === "object" && value !== null && "name" in value; +} + export const handlers = [ http.get("/health", () => { return HttpResponse.json({ status: "ok" }); @@ -267,20 +277,52 @@ export const handlers = [ return HttpResponse.json({ accounts: state.accounts }); }), - http.post("/api/accounts/import", async () => { - const sequence = state.accounts.length + 1; - const created = createAccountSummary({ - accountId: `acc_imported_${sequence}`, - email: `imported-${sequence}@example.com`, - displayName: `imported-${sequence}@example.com`, - status: "active", - }); - state.accounts = [...state.accounts, created]; - return HttpResponse.json({ - accountId: created.accountId, - email: created.email, - planType: created.planType, - status: created.status, + http.post("/api/accounts/import/batch", async ({ request }) => { + const formData = await request.formData(); + const uploadedFiles = formData.getAll("auth_json").filter(isUploadedFileLike); + + const imported: AccountImportResponse[] = []; + const failed: AccountImportFailure[] = []; + + for (const file of uploadedFiles) { + if (file.name.toLowerCase().includes("invalid")) { + failed.push( + createAccountImportFailure({ + filename: file.name, + message: "Invalid auth.json payload", + }), + ); + continue; + } + + const sequence = state.accounts.length + 1; + const created = createAccountSummary({ + accountId: `acc_imported_${sequence}`, + email: `${file.name.replace(/\.json$/i, "")}@example.com`, + displayName: `${file.name.replace(/\.json$/i, "")}@example.com`, + status: "active", + }); + state.accounts = [...state.accounts, created]; + imported.push( + createAccountImportResponse({ + filename: file.name, + accountId: created.accountId, + email: created.email, + planType: created.planType, + status: created.status, + }), + ); + } + + return HttpResponse.json(createAccountImportBatchResponse({ imported, failed })); + }), + + http.get("/api/accounts/export", () => { + return new HttpResponse(new Uint8Array([80, 75, 3, 4]), { + headers: { + "Content-Type": "application/zip", + "Content-Disposition": 'attachment; filename="auth-export-test.zip"', + }, }); }), diff --git a/frontend/src/utils/formatters.test.ts b/frontend/src/utils/formatters.test.ts index ac8d499b..92ec6486 100644 --- a/frontend/src/utils/formatters.test.ts +++ b/frontend/src/utils/formatters.test.ts @@ -131,13 +131,20 @@ describe("formatters", () => { expect( formatAccessTokenLabel({ access: { expiresAt: "1970-01-01T00:00:00.000Z" }, + refresh: { state: "stored" }, }), - ).toBe("Expired"); + ).toBe("Expired (refresh available)"); expect( formatAccessTokenLabel({ access: { expiresAt: future }, }), ).toBe("Valid (in 2h)"); + expect( + formatAccessTokenLabel({ + access: { expiresAt: null }, + refresh: { state: "stored" }, + }), + ).toBe("Refreshable"); expect( formatRefreshTokenLabel({ diff --git a/frontend/src/utils/formatters.ts b/frontend/src/utils/formatters.ts index 1fcc785e..a64dfcd4 100644 --- a/frontend/src/utils/formatters.ts +++ b/frontend/src/utils/formatters.ts @@ -247,8 +247,9 @@ export function truncateText(value: unknown, maxLen = 80): string { export function formatAccessTokenLabel(auth: AccountAuthStatus | null | undefined): string { const expiresAt = auth?.access?.expiresAt; + const refreshState = auth?.refresh?.state; if (!expiresAt) { - return "Missing"; + return refreshState === "stored" ? "Refreshable" : "Missing"; } const expiresDate = parseDate(expiresAt); if (!expiresDate) { @@ -256,7 +257,7 @@ export function formatAccessTokenLabel(auth: AccountAuthStatus | null | undefine } const diffMs = expiresDate.getTime() - Date.now(); if (diffMs <= 0) { - return "Expired"; + return refreshState === "stored" ? "Expired (refresh available)" : "Expired"; } return `Valid (${formatRelative(diffMs)})`; } diff --git a/jira.sh b/jira.sh new file mode 100644 index 00000000..98061611 --- /dev/null +++ b/jira.sh @@ -0,0 +1,320 @@ +#!/usr/bin/env bash +set -euo pipefail + +JIRA_BASE="https://192.168.214.2:8443" +JIRA_USER="a.markov-buturskiy" +JIRA_PASS="patay228" + +# Твои проекты +PROJECTS="SES,XFIVE,AF,ROOT" + +# Примерное окно аварии +FROM="2026-03-12 05:00" +TO="2026-03-12 05:30" +FILTER_BY_WINDOW=false + +# Кто был assignee/closer при аварийном закрытии +MY_JIRA_USERNAME="a.markov-buturskiy" + +# dry-run сначала true +DRY_RUN=true +CACHE_ENABLED=true +CACHE_TTL_SECONDS=300 +CACHE_DIR="${HOME}/.cache/jira.sh" +ISSUE_CACHE_DIR="${CACHE_DIR}/issue-history" + +api() { + curl -sk -u "$JIRA_USER:$JIRA_PASS" -H 'Content-Type: application/json' "$@" +} + +cache_key_for_jql() { + local key_source="$1" + printf '%s' "$key_source" | shasum -a 256 | awk '{print $1}' +} + +cache_fresh() { + local path="$1" + local now modified age + now="$(date +%s)" + modified="$(stat -f %m "$path" 2>/dev/null || echo 0)" + age=$((now - modified)) + if (( age <= CACHE_TTL_SECONDS )); then + return 0 + fi + return 1 +} + +cache_age_seconds() { + local path="$1" + local now modified + now="$(date +%s)" + modified="$(stat -f %m "$path" 2>/dev/null || echo 0)" + echo $((now - modified)) +} + +get_issue_snapshot_cached() { + local issue="$1" + local out_file="$2" + local issue_cache_file + issue_cache_file="${ISSUE_CACHE_DIR}/${issue}.json" + + if [[ "$CACHE_ENABLED" == "true" && -f "$issue_cache_file" ]] && cache_fresh "$issue_cache_file"; then + cp "$issue_cache_file" "$out_file" + return 0 + fi + + api "${JIRA_BASE}/rest/api/2/issue/${issue}?expand=changelog&fields=comment,status,assignee" > "$out_file" + if [[ "$CACHE_ENABLED" == "true" ]]; then + mkdir -p "$ISSUE_CACHE_DIR" + cp "$out_file" "$issue_cache_file" + fi + return 1 +} + +issue_has_exact_ci_comment() { + local issue_json="$1" + if jq -e 'any(.fields.comment.comments[]?; ((.body // "") | gsub("\r"; "") | gsub("^\\s+|\\s+$"; "")) == "CI")' "$issue_json" >/dev/null; then + return 0 + fi + return 1 +} + +count_exact_ci_comments() { + local issue_json="$1" + jq -r '[.fields.comment.comments[]? | select(((.body // "") | gsub("\r"; "") | gsub("^\\s+|\\s+$"; "")) == "CI")] | length' "$issue_json" +} + +echo "Fetching affected issues by JQL..." +SEARCH_JSON="$(mktemp)" +JQL='comment ~ "CI"' +if [[ "$FILTER_BY_WINDOW" == "true" ]]; then + JQL="${JQL} AND updated >= \"${FROM}\" AND updated <= \"${TO}\"" +fi + +CACHE_KEY="$(cache_key_for_jql "${JIRA_BASE}|${JQL}")" +SEARCH_CACHE_FILE="${CACHE_DIR}/search-${CACHE_KEY}.json" +SEARCH_CACHE_HIT=false + +if [[ "$CACHE_ENABLED" == "true" && -f "$SEARCH_CACHE_FILE" ]] && cache_fresh "$SEARCH_CACHE_FILE"; then + cp "$SEARCH_CACHE_FILE" "$SEARCH_JSON" + SEARCH_CACHE_HIT=true +else + api \ + --get \ + --data-urlencode "jql=${JQL}" \ + --data-urlencode "maxResults=1000" \ + --data-urlencode "fields=key,status,assignee" \ + "${JIRA_BASE}/rest/api/2/search" > "$SEARCH_JSON" + if [[ "$CACHE_ENABLED" == "true" ]]; then + mkdir -p "$CACHE_DIR" + cp "$SEARCH_JSON" "$SEARCH_CACHE_FILE" + fi +fi + +if jq -e '(.errorMessages // []) | length > 0' "$SEARCH_JSON" >/dev/null; then + echo "Jira search returned errors:" >&2 + jq -r '.errorMessages[]' "$SEARCH_JSON" >&2 + rm -f "$SEARCH_JSON" + exit 1 +fi + +if [[ "$DRY_RUN" == "true" ]]; then + echo + echo "DRY_RUN=true: reporting affected issues only (no Jira mutations)." + echo "Projects used: ALL" + echo "Filter by window: ${FILTER_BY_WINDOW}" + if [[ "$SEARCH_CACHE_HIT" == "true" ]]; then + echo "Search cache: hit ($(cache_age_seconds "$SEARCH_CACHE_FILE")s old)" + else + echo "Search cache: miss" + fi + mapfile -t CANDIDATE_ISSUES < <(jq -r '(.issues // [])[]?.key' "$SEARCH_JSON") + CANDIDATE_COUNT="${#CANDIDATE_ISSUES[@]}" + AFFECTED_ISSUES=() + PLAN_ROWS=() + ISSUE_CACHE_HITS=0 + ISSUE_CACHE_MISSES=0 + for issue in "${CANDIDATE_ISSUES[@]}"; do + ISSUE_JSON="$(mktemp)" + if get_issue_snapshot_cached "$issue" "$ISSUE_JSON"; then + ISSUE_CACHE_HITS=$((ISSUE_CACHE_HITS + 1)) + else + ISSUE_CACHE_MISSES=$((ISSUE_CACHE_MISSES + 1)) + fi + + if issue_has_exact_ci_comment "$ISSUE_JSON"; then + AFFECTED_ISSUES+=("$issue") + CI_COMMENT_COUNT="$(count_exact_ci_comments "$ISSUE_JSON")" + + CUR_STATUS=$(jq -r '.fields.status.name // empty' "$ISSUE_JSON") + + if [[ "$FILTER_BY_WINDOW" == "true" ]]; then + LAST_STATUS_EVENT=$(jq -c --arg FROM "$FROM" --arg TO "$TO" ' + .changelog.histories + | map(select(.created >= $FROM and .created <= $TO)) + | map(select(any(.items[]?; .field=="status"))) + | sort_by(.created) + | last + ' "$ISSUE_JSON") + else + LAST_STATUS_EVENT=$(jq -c ' + .changelog.histories + | map(select(any(.items[]?; .field=="status"))) + | sort_by(.created) + | last + ' "$ISSUE_JSON") + fi + + LAST_CLOSER="" + PREV_STATUS="" + PREV_ASSIGNEE="" + if [[ "$LAST_STATUS_EVENT" != "null" && -n "$LAST_STATUS_EVENT" ]]; then + LAST_CLOSER=$(echo "$LAST_STATUS_EVENT" | jq -r '.author.name // .author.key // .author.displayName // empty') + PREV_STATUS=$(echo "$LAST_STATUS_EVENT" | jq -r '.items[] | select(.field=="status") | .fromString // empty') + PREV_ASSIGNEE=$(jq -r --arg EVENT_CREATED "$(echo "$LAST_STATUS_EVENT" | jq -r '.created')" ' + .changelog.histories + | map(select(.created <= $EVENT_CREATED)) + | map(select(any(.items[]?; .field=="assignee"))) + | sort_by(.created) + | last + | .items[]? + | select(.field=="assignee") + | .from // empty + ' "$ISSUE_JSON") + fi + + DO_ROLLBACK=false + TARGET_TRANSITION_NAME="" + if [[ -n "${LAST_CLOSER}" && "${LAST_CLOSER}" == "${MY_JIRA_USERNAME}" ]]; then + DO_ROLLBACK=true + TARGET_TRANSITION_NAME="Reopen Issue" + fi + + PLAN_ROWS+=("${issue}|status=${CUR_STATUS:-}|rollback=${DO_ROLLBACK}|transition=${TARGET_TRANSITION_NAME:-}|restore_assignee=${PREV_ASSIGNEE:-}|delete_ci_comments=${CI_COMMENT_COUNT}|closed_by=${LAST_CLOSER:-}") + fi + rm -f "$ISSUE_JSON" + done + + AFFECTED_COUNT="${#AFFECTED_ISSUES[@]}" + echo "Candidate issues by JQL: ${CANDIDATE_COUNT}" + echo "Issue history cache: hit=${ISSUE_CACHE_HITS}, miss=${ISSUE_CACHE_MISSES}" + echo "Affected issues count (exact CI comment): ${AFFECTED_COUNT}" + if (( AFFECTED_COUNT > 0 )); then + printf '%s\n' "${AFFECTED_ISSUES[@]}" + fi + if (( ${#PLAN_ROWS[@]} > 0 )); then + echo + echo "Plan preview (read-only):" + printf '%s\n' "${PLAN_ROWS[@]}" + fi + rm -f "$SEARCH_JSON" + exit 0 +fi + +jq -r '.issues[].key' "$SEARCH_JSON" | while read -r ISSUE; do + echo + echo "==== $ISSUE ====" + + ISSUE_JSON="$(mktemp)" + api "${JIRA_BASE}/rest/api/2/issue/${ISSUE}?expand=changelog" > "$ISSUE_JSON" + + CUR_STATUS=$(jq -r '.fields.status.name' "$ISSUE_JSON") + CUR_ASSIGNEE=$(jq -r '.fields.assignee.name // .fields.assignee.key // .fields.assignee.displayName // empty' "$ISSUE_JSON") + + # Последнее изменение статуса в окне аварии + LAST_STATUS_EVENT=$(jq -c --arg FROM "$FROM" --arg TO "$TO" ' + .changelog.histories + | map(select(.created >= $FROM and .created <= $TO)) + | map(select(any(.items[]?; .field=="status"))) + | sort_by(.created) + | last + ' "$ISSUE_JSON") + + if [[ "$LAST_STATUS_EVENT" == "null" || -z "$LAST_STATUS_EVENT" ]]; then + echo "No status change in target window, only deleting CI comments" + LAST_CLOSER="" + PREV_STATUS="" + PREV_ASSIGNEE="" + else + LAST_CLOSER=$(echo "$LAST_STATUS_EVENT" | jq -r '.author.name // .author.key // .author.displayName // empty') + PREV_STATUS=$(echo "$LAST_STATUS_EVENT" | jq -r '.items[] | select(.field=="status") | .fromString // empty') + + # Ищем assignee ДО этого события + PREV_ASSIGNEE=$(jq -r --arg EVENT_CREATED "$(echo "$LAST_STATUS_EVENT" | jq -r '.created')" ' + .changelog.histories + | map(select(.created <= $EVENT_CREATED)) + | map(select(any(.items[]?; .field=="assignee"))) + | sort_by(.created) + | last + | .items[]? + | select(.field=="assignee") + | .from // empty + ' "$ISSUE_JSON") + fi + + echo "Current status: ${CUR_STATUS}" + echo "Current assignee: ${CUR_ASSIGNEE:-}" + echo "Closed by: ${LAST_CLOSER:-}" + echo "Prev status: ${PREV_STATUS:-}" + echo "Prev assignee: ${PREV_ASSIGNEE:-}" + + DO_ROLLBACK=false + TARGET_TRANSITION_NAME="" + + if [[ -n "${LAST_CLOSER}" && "${LAST_CLOSER}" == "${MY_JIRA_USERNAME}" ]]; then + DO_ROLLBACK=true + TARGET_TRANSITION_NAME="Reopen Issue" + fi + + echo "Rollback: ${DO_ROLLBACK}" + echo "Transition: ${TARGET_TRANSITION_NAME:-}" + + if [[ "$DRY_RUN" == "false" ]]; then + if [[ "$DO_ROLLBACK" == "true" && -n "$TARGET_TRANSITION_NAME" ]]; then + TRANSITIONS_JSON="$(mktemp)" + api "${JIRA_BASE}/rest/api/2/issue/${ISSUE}/transitions" > "$TRANSITIONS_JSON" + + TRANSITION_ID=$(jq -r --arg NAME "$TARGET_TRANSITION_NAME" ' + .transitions[] | select(.name == $NAME) | .id + ' "$TRANSITIONS_JSON" | head -n1) + + if [[ -n "${TRANSITION_ID}" && "${TRANSITION_ID}" != "null" ]]; then + echo "Applying transition ${TARGET_TRANSITION_NAME} (${TRANSITION_ID})" + api -X POST \ + --data "{\"transition\":{\"id\":\"${TRANSITION_ID}\"}}" \ + "${JIRA_BASE}/rest/api/2/issue/${ISSUE}/transitions" >/dev/null + else + echo "No transition found: ${TARGET_TRANSITION_NAME}" + fi + + rm -f "$TRANSITIONS_JSON" + + if [[ -n "${PREV_ASSIGNEE}" ]]; then + echo "Restoring assignee to ${PREV_ASSIGNEE}" + api -X PUT \ + --data "{\"name\":\"${PREV_ASSIGNEE}\"}" \ + "${JIRA_BASE}/rest/api/2/issue/${ISSUE}/assignee" >/dev/null || true + fi + fi + + COMMENTS_JSON="$(mktemp)" + api "${JIRA_BASE}/rest/api/2/issue/${ISSUE}/comment" > "$COMMENTS_JSON" + + jq -r ' + .comments[] + | select(.body == "CI") + | .id + ' "$COMMENTS_JSON" | while read -r COMMENT_ID; do + echo "Deleting comment ${COMMENT_ID}" + api -X DELETE "${JIRA_BASE}/rest/api/2/issue/${ISSUE}/comment/${COMMENT_ID}" >/dev/null || true + done + + rm -f "$COMMENTS_JSON" + fi + + rm -f "$ISSUE_JSON" +done + +rm -f "$SEARCH_JSON" +echo +echo "Done." diff --git a/openspec/changes/bulk-auth-import-export/design.md b/openspec/changes/bulk-auth-import-export/design.md new file mode 100644 index 00000000..84d32ada --- /dev/null +++ b/openspec/changes/bulk-auth-import-export/design.md @@ -0,0 +1,40 @@ +# Design: bulk-auth-import-export + +## Summary + +Keep the existing `POST /api/accounts/import` route for single-file imports and add `POST /api/accounts/import/batch` for repeated `auth_json` multipart parts. The batch route returns `imported` and `failed` arrays. Each file is processed independently so one malformed or conflicting payload does not discard successful imports in the same request. + +Add `GET /api/accounts/export` that returns a zip archive of current auth payloads built from the persisted account records. Before serializing each payload, attempt to refresh tokens when the stored access token is already expired or the persisted `last_refresh` is beyond the usual refresh threshold. + +## Import Flow + +1. Batch API reads all uploaded `auth_json` files. +2. Service parses each file independently. +3. If the uploaded access token is expired and a refresh token is present, service performs an immediate refresh exchange before saving the account. +4. Service upserts the account using the existing repository conflict policy. +5. Service records either an imported item or a failed item for each uploaded file. + +## Export Flow + +1. Service loads all persisted accounts. +2. For each account, decrypt stored tokens and attempt a best-effort refresh when the access token is expired or due for refresh. +3. Service serializes the current token set back into the existing `auth.json` shape. +4. Service writes one `auth.json` per account into an in-memory zip archive and returns it as a download. + +## API Shape + +- `POST /api/accounts/import` + - request: multipart form with one `auth_json` file part + - response: single imported account result +- `POST /api/accounts/import/batch` + - request: multipart form with one or more `auth_json` file parts + - response: `{ imported: [...], failed: [...] }` +- `GET /api/accounts/export` + - response: `application/zip` + +## Failure Handling + +- Malformed payloads map to `invalid_auth_json` entries in `failed`. +- Import identity conflicts map to `duplicate_identity_conflict` entries in `failed`. +- Import-time refresh failures map to `refresh_failed` entries in `failed`. +- Export is best-effort for token freshness: if a refresh attempt fails, the archive still includes the latest persisted token set for that account. diff --git a/openspec/changes/bulk-auth-import-export/proposal.md b/openspec/changes/bulk-auth-import-export/proposal.md new file mode 100644 index 00000000..b0934a17 --- /dev/null +++ b/openspec/changes/bulk-auth-import-export/proposal.md @@ -0,0 +1,17 @@ +# Proposal: bulk-auth-import-export + +## Why + +Account import currently accepts only one `auth.json` file at a time, leaves already-expired access tokens untouched until a later runtime refresh path happens, and offers no way to export the current account auth payloads from the dashboard. That creates unnecessary operator work and makes backup or migration flows slower and less reliable. + +## What Changes + +- Allow importing multiple `auth.json` files in one dashboard action. +- Refresh imported auth immediately when the uploaded access token is expired but the refresh token can still mint a fresh token set. +- Add a dashboard export endpoint that downloads a zip archive containing one current `auth.json` payload per stored account. + +## Impact + +- Backend accounts API adds a batch import endpoint and auth zip export endpoint. +- Frontend Accounts page import dialog and actions must support selecting multiple files and downloading the export zip. +- Tests must cover partial import success, import-time refresh, and zip archive content. diff --git a/openspec/changes/bulk-auth-import-export/specs/frontend-architecture/spec.md b/openspec/changes/bulk-auth-import-export/specs/frontend-architecture/spec.md new file mode 100644 index 00000000..9826445d --- /dev/null +++ b/openspec/changes/bulk-auth-import-export/specs/frontend-architecture/spec.md @@ -0,0 +1,23 @@ +## MODIFIED Requirements + +### Requirement: Accounts page + +The Accounts page SHALL display a two-column layout: left panel with searchable account list, import button, export button, and add account button; right panel with selected account details including usage, token info, and actions (pause/resume/delete/re-authenticate). + +#### Scenario: Batch account import + +- **WHEN** a user clicks the import button and uploads one or more `auth.json` files +- **THEN** the app calls `POST /api/accounts/import/batch` +- **AND** the response reports imported and failed files independently +- **AND** the account list is refreshed when at least one file imports successfully + +#### Scenario: Import refreshes expired access token + +- **WHEN** an uploaded `auth.json` contains an expired access token +- **AND** the refresh token is still valid +- **THEN** the backend refreshes the token set before persisting the account + +#### Scenario: Export current auth payload archive + +- **WHEN** a user clicks the export button +- **THEN** the app downloads a zip archive containing one current `auth.json` payload per stored account diff --git a/openspec/changes/bulk-auth-import-export/tasks.md b/openspec/changes/bulk-auth-import-export/tasks.md new file mode 100644 index 00000000..1654fd6e --- /dev/null +++ b/openspec/changes/bulk-auth-import-export/tasks.md @@ -0,0 +1,8 @@ +# Tasks: bulk-auth-import-export + +- [x] Update `frontend-architecture` spec delta for batch import and auth zip export behavior. +- [x] Implement backend batch import result schema and route handling. +- [x] Implement import-time token refresh for expired uploaded access tokens. +- [x] Implement backend auth zip export endpoint and payload serialization. +- [x] Update frontend Accounts page import dialog and mutations for multi-file upload and zip export. +- [x] Add backend and frontend tests covering the new flows. diff --git a/openspec/specs/api-keys/spec.md b/openspec/specs/api-keys/spec.md index c9b22aec..c37c7b7e 100644 --- a/openspec/specs/api-keys/spec.md +++ b/openspec/specs/api-keys/spec.md @@ -127,6 +127,32 @@ The dependency SHALL raise a domain exception on validation failure. The excepti - **WHEN** `api_key_auth_enabled` is false - **THEN** the dependency returns `None` and the request proceeds without authentication +### Requirement: Optional additional proxy header guard + +The system SHALL support an optional additional proxy guard controlled by `CODEX_LB_PROXY_KEY_AUTH_ENABLED` and `CODEX_LB_PROXY_KEY`. When enabled, proxy requests MUST provide a matching `X-Codex-Proxy-Key` header in addition to standard Bearer validation (when API key auth is enabled). This header MUST NOT act as an alternative credential path. + +#### Scenario: Optional proxy key guard disabled by default + +- **WHEN** `CODEX_LB_PROXY_KEY_AUTH_ENABLED` is not enabled +- **THEN** requests are evaluated without requiring `X-Codex-Proxy-Key` + +#### Scenario: Optional proxy key guard enabled + +- **WHEN** `CODEX_LB_PROXY_KEY_AUTH_ENABLED=true` and `CODEX_LB_PROXY_KEY` is configured +- **THEN** a request missing `X-Codex-Proxy-Key` is rejected with 401 +- **AND** a request with a non-matching `X-Codex-Proxy-Key` is rejected with 401 +- **AND** a request with a matching `X-Codex-Proxy-Key` continues to normal Bearer API key validation flow + +#### Scenario: Optional proxy key header is not a Bearer substitute + +- **WHEN** `api_key_auth_enabled` is true and a request sends only `X-Codex-Proxy-Key` +- **THEN** the request is rejected with 401 due to missing/invalid Bearer API key + +#### Scenario: Misconfigured enabled guard fails closed + +- **WHEN** `CODEX_LB_PROXY_KEY_AUTH_ENABLED=true` but `CODEX_LB_PROXY_KEY` is missing or blank +- **THEN** guarded proxy requests are rejected with 401 + ### Requirement: Model restriction enforcement The system SHALL enforce per-key model restrictions in the proxy service layer (not middleware). When `allowed_models` is set (non-null, non-empty) and the requested model is not in the list, the system MUST reject the request. The `/v1/models` endpoint MUST filter the model list based on the authenticated key's `allowed_models`. diff --git a/openspec/specs/frontend-architecture/spec.md b/openspec/specs/frontend-architecture/spec.md index ab7ec51d..5e1b7d8a 100644 --- a/openspec/specs/frontend-architecture/spec.md +++ b/openspec/specs/frontend-architecture/spec.md @@ -109,17 +109,24 @@ The Dashboard recent requests table SHALL display each row's recorded request tr ### Requirement: Accounts page -The Accounts page SHALL display a two-column layout: left panel with searchable account list, import button, and add account button; right panel with selected account details including usage, token info, and actions (pause/resume/delete/re-authenticate). +The Accounts page SHALL display a two-column layout: left panel with searchable account list, import button, export button, and add account button; right panel with selected account details including usage, token info, and actions (pause/resume/delete/re-authenticate). #### Scenario: Account selection - **WHEN** a user clicks an account in the list - **THEN** the right panel shows the selected account's details -#### Scenario: Account import +#### Scenario: Batch account import -- **WHEN** a user clicks the import button and uploads an auth.json file -- **THEN** the app calls `POST /api/accounts/import` and refreshes the account list on success +- **WHEN** a user clicks the import button and uploads one or more `auth.json` files +- **THEN** the app calls `POST /api/accounts/import/batch` +- **AND** the response reports imported and failed files independently +- **AND** the account list is refreshed when at least one file imports successfully + +#### Scenario: Export current auth payload archive + +- **WHEN** a user clicks the export button +- **THEN** the app downloads a zip archive containing one current `auth.json` payload per stored account #### Scenario: Ambiguous duplicate identity import conflict diff --git a/scripts/docker-entrypoint.sh b/scripts/docker-entrypoint.sh index 31f66a2f..a09dfdfd 100755 --- a/scripts/docker-entrypoint.sh +++ b/scripts/docker-entrypoint.sh @@ -1,7 +1,13 @@ #!/bin/sh set -eu +if [ -f /app/.env ]; then + set -a + . /app/.env + set +a +fi + python -m app.db.migrate upgrade export CODEX_LB_DATABASE_MIGRATE_ON_STARTUP=false -exec fastapi run --host 0.0.0.0 --port 2455 +exec python -m app.cli --host 0.0.0.0 --port "${PORT:-2455}" diff --git a/tests/conftest.py b/tests/conftest.py index 89e8eec8..ae9160af 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,7 @@ TEST_DB_DIR = Path(tempfile.mkdtemp(prefix="codex-lb-tests-")) TEST_DB_PATH = TEST_DB_DIR / "codex-lb.db" +TEST_DATABASE_URL = os.environ.get("CODEX_LB_TEST_DATABASE_URL", f"sqlite+aiosqlite:///{TEST_DB_PATH}") os.environ["CODEX_LB_DATABASE_URL"] = os.environ.get( "CODEX_LB_TEST_DATABASE_URL", f"sqlite+aiosqlite:///{TEST_DB_PATH}" @@ -22,13 +23,25 @@ os.environ["CODEX_LB_STICKY_SESSION_CLEANUP_ENABLED"] = "false" from app.db.models import Base # noqa: E402 +from app.db.migrate import run_startup_migrations # noqa: E402 from app.db.session import engine # noqa: E402 from app.main import create_app # noqa: E402 -@pytest_asyncio.fixture -async def app_instance(): - app = create_app() +async def _reset_database_via_migrations() -> None: + async with engine.begin() as conn: + + def _reset(sync_conn): + sync_conn.execute(text("DROP TABLE IF EXISTS alembic_version")) + sync_conn.execute(text("DROP TABLE IF EXISTS schema_migrations")) + Base.metadata.drop_all(sync_conn) + + await conn.run_sync(_reset) + + await run_startup_migrations(TEST_DATABASE_URL) + + +async def _reset_database_raw() -> None: async with engine.begin() as conn: def _reset(sync_conn): @@ -38,6 +51,12 @@ def _reset(sync_conn): Base.metadata.create_all(sync_conn) await conn.run_sync(_reset) + + +@pytest_asyncio.fixture +async def app_instance(): + app = create_app() + await _reset_database_via_migrations() return app @@ -49,15 +68,13 @@ async def dispose_engine(): @pytest_asyncio.fixture async def db_setup(): - async with engine.begin() as conn: + await _reset_database_via_migrations() + return True - def _reset(sync_conn): - sync_conn.execute(text("DROP TABLE IF EXISTS alembic_version")) - sync_conn.execute(text("DROP TABLE IF EXISTS schema_migrations")) - Base.metadata.drop_all(sync_conn) - Base.metadata.create_all(sync_conn) - await conn.run_sync(_reset) +@pytest_asyncio.fixture +async def raw_db_setup(): + await _reset_database_raw() return True diff --git a/tests/integration/test_accounts_api_extended.py b/tests/integration/test_accounts_api_extended.py index c210d8aa..4d372167 100644 --- a/tests/integration/test_accounts_api_extended.py +++ b/tests/integration/test_accounts_api_extended.py @@ -1,12 +1,15 @@ from __future__ import annotations import base64 +import io import json +import zipfile from datetime import datetime, timedelta, timezone import pytest from app.core.auth import fallback_account_id, generate_unique_account_id +from app.core.auth.refresh import TokenRefreshResult from app.core.crypto import TokenEncryptor from app.core.utils.time import utcnow from app.db.models import Account, AccountStatus @@ -24,17 +27,27 @@ def _encode_jwt(payload: dict) -> str: return f"header.{body}.sig" -def _make_auth_json(account_id: str | None, email: str, plan_type: str = "plus") -> dict: +def _make_auth_json( + account_id: str | None, + email: str, + plan_type: str = "plus", + *, + access_exp: int | None = None, + refresh_token: str = "refresh", +) -> dict: payload = { "email": email, "https://api.openai.com/auth": {"chatgpt_plan_type": plan_type}, } if account_id: payload["chatgpt_account_id"] = account_id + access_token = "access" + if access_exp is not None: + access_token = _encode_jwt({"exp": access_exp}) tokens: dict[str, object] = { "idToken": _encode_jwt(payload), - "accessToken": "access", - "refreshToken": "refresh", + "accessToken": access_token, + "refreshToken": refresh_token, } if account_id: tokens["accountId"] = account_id @@ -97,6 +110,120 @@ async def test_import_falls_back_to_email_based_account_id(async_client): assert payload["email"] == email +@pytest.mark.asyncio +async def test_batch_import_reports_successes_and_failures(async_client): + valid = _make_auth_json("acc_batch", "batch@example.com") + files = [ + ("auth_json", ("valid.json", json.dumps(valid), "application/json")), + ("auth_json", ("broken.json", "not-json", "application/json")), + ] + + response = await async_client.post("/api/accounts/import/batch", files=files) + + assert response.status_code == 200 + payload = response.json() + assert len(payload["imported"]) == 1 + assert payload["imported"][0]["filename"] == "valid.json" + assert payload["imported"][0]["email"] == "batch@example.com" + assert payload["failed"] == [ + { + "filename": "broken.json", + "code": "invalid_auth_json", + "message": "Invalid auth.json payload", + } + ] + + +@pytest.mark.asyncio +async def test_import_refreshes_expired_access_token_when_refresh_is_valid(async_client, monkeypatch): + old_refresh_token = "refresh-old" + new_access_token = _encode_jwt({"exp": int((utcnow() + timedelta(hours=2)).timestamp())}) + new_id_token = _encode_jwt( + { + "email": "refresh-import@example.com", + "chatgpt_account_id": "acc_refresh_import", + "https://api.openai.com/auth": {"chatgpt_plan_type": "pro"}, + } + ) + + async def _fake_refresh(refresh_token: str) -> TokenRefreshResult: + assert refresh_token == old_refresh_token + return TokenRefreshResult( + access_token=new_access_token, + refresh_token="refresh-new", + id_token=new_id_token, + account_id="acc_refresh_import", + plan_type="pro", + email="refresh-import@example.com", + ) + + monkeypatch.setattr("app.modules.accounts.service.refresh_access_token", _fake_refresh) + + expired_auth = _make_auth_json( + "acc_refresh_import", + "refresh-import@example.com", + "plus", + access_exp=int((utcnow() - timedelta(hours=1)).timestamp()), + refresh_token=old_refresh_token, + ) + + response = await async_client.post( + "/api/accounts/import", + files={"auth_json": ("expired.json", json.dumps(expired_auth), "application/json")}, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["refreshedOnImport"] is True + assert payload["planType"] == "pro" + + account_id = generate_unique_account_id("acc_refresh_import", "refresh-import@example.com") + async with SessionLocal() as session: + repo = AccountsRepository(session) + saved = await repo.get_by_id(account_id) + + assert saved is not None + encryptor = TokenEncryptor() + assert encryptor.decrypt(saved.access_token_encrypted) == new_access_token + assert encryptor.decrypt(saved.refresh_token_encrypted) == "refresh-new" + assert encryptor.decrypt(saved.id_token_encrypted) == new_id_token + + +@pytest.mark.asyncio +async def test_export_accounts_downloads_zip_with_current_auth_payloads(async_client): + alpha = _make_auth_json("acc_export_alpha", "alpha@example.com", "plus") + beta = _make_auth_json("acc_export_beta", "beta@example.com", "team") + + first = await async_client.post( + "/api/accounts/import", + files={"auth_json": ("alpha.json", json.dumps(alpha), "application/json")}, + ) + second = await async_client.post( + "/api/accounts/import", + files={"auth_json": ("beta.json", json.dumps(beta), "application/json")}, + ) + assert first.status_code == 200 + assert second.status_code == 200 + + response = await async_client.get("/api/accounts/export") + + assert response.status_code == 200 + assert response.headers["content-type"] == "application/zip" + assert "attachment; filename=" in response.headers["content-disposition"] + + archive = zipfile.ZipFile(io.BytesIO(response.content)) + names = sorted(archive.namelist()) + assert any(name.endswith("/auth.json") for name in names) + assert any("alpha-example.com__" in name for name in names) + assert any("beta-example.com__" in name for name in names) + + exported = [json.loads(archive.read(name).decode("utf-8")) for name in names if name.endswith("/auth.json")] + exported_by_account = {item["tokens"]["accountId"]: item for item in exported} + assert exported_by_account["acc_export_alpha"]["tokens"]["refreshToken"] == "refresh" + assert exported_by_account["acc_export_beta"]["tokens"]["idToken"] + assert exported_by_account["acc_export_alpha"]["lastRefreshAt"] + + @pytest.mark.asyncio async def test_import_overwrites_by_default_for_same_account_identity(async_client): email = "same-default@example.com" diff --git a/tests/integration/test_auth_middleware.py b/tests/integration/test_auth_middleware.py index 6a809e00..3bdb412a 100644 --- a/tests/integration/test_auth_middleware.py +++ b/tests/integration/test_auth_middleware.py @@ -2,6 +2,7 @@ import logging from datetime import timedelta +from types import SimpleNamespace import pytest @@ -153,6 +154,20 @@ async def test_api_key_branch_disabled_then_enabled(async_client): valid = await async_client.get("/v1/models", headers={"Authorization": f"Bearer {created.key}"}) assert valid.status_code == 200 + only_proxy_header = await async_client.get("/v1/models", headers={"X-Codex-Proxy-Key": created.key}) + assert only_proxy_header.status_code == 401 + assert only_proxy_header.json()["error"]["code"] == "invalid_api_key" + + invalid_bearer_with_proxy_header = await async_client.get( + "/v1/models", + headers={ + "Authorization": "Bearer invalid-key", + "X-Codex-Proxy-Key": created.key, + }, + ) + assert invalid_bearer_with_proxy_header.status_code == 401 + assert invalid_bearer_with_proxy_header.json()["error"]["code"] == "invalid_api_key" + async with SessionLocal() as session: repo = ApiKeysRepository(session) row = await repo.get_by_id(created.id) @@ -190,6 +205,97 @@ async def test_api_key_branch_disabled_then_enabled(async_client): assert over_limit.json()["error"]["code"] == "rate_limit_exceeded" +@pytest.mark.asyncio +async def test_optional_proxy_key_header_is_additional_guard(async_client, monkeypatch): + enable = await async_client.put( + "/api/settings", + json={ + "stickyThreadsEnabled": False, + "preferEarlierResetAccounts": False, + "totpRequiredOnLogin": False, + "apiKeyAuthEnabled": True, + }, + ) + assert enable.status_code == 200 + + async with SessionLocal() as session: + service = ApiKeysService(ApiKeysRepository(session)) + created = await service.create_key( + ApiKeyCreateData( + name="proxy-key-guard", + allowed_models=None, + expires_at=None, + ) + ) + + monkeypatch.setattr( + "app.core.auth.dependencies.get_settings", + lambda: SimpleNamespace(proxy_key_auth_enabled=True, proxy_key="proxy-shared-secret"), + ) + + missing_proxy_key = await async_client.get("/v1/models", headers={"Authorization": f"Bearer {created.key}"}) + assert missing_proxy_key.status_code == 401 + assert missing_proxy_key.json()["error"]["code"] == "invalid_api_key" + + invalid_proxy_key = await async_client.get( + "/v1/models", + headers={ + "Authorization": f"Bearer {created.key}", + "X-Codex-Proxy-Key": "wrong-secret", + }, + ) + assert invalid_proxy_key.status_code == 401 + assert invalid_proxy_key.json()["error"]["code"] == "invalid_api_key" + + valid_proxy_key = await async_client.get( + "/v1/models", + headers={ + "Authorization": f"Bearer {created.key}", + "X-Codex-Proxy-Key": "proxy-shared-secret", + }, + ) + assert valid_proxy_key.status_code == 200 + + +@pytest.mark.asyncio +async def test_optional_proxy_key_header_enabled_without_config_fails(async_client, monkeypatch): + enable = await async_client.put( + "/api/settings", + json={ + "stickyThreadsEnabled": False, + "preferEarlierResetAccounts": False, + "totpRequiredOnLogin": False, + "apiKeyAuthEnabled": True, + }, + ) + assert enable.status_code == 200 + + async with SessionLocal() as session: + service = ApiKeysService(ApiKeysRepository(session)) + created = await service.create_key( + ApiKeyCreateData( + name="proxy-key-misconfig", + allowed_models=None, + expires_at=None, + ) + ) + + monkeypatch.setattr( + "app.core.auth.dependencies.get_settings", + lambda: SimpleNamespace(proxy_key_auth_enabled=True, proxy_key=None), + ) + + misconfigured = await async_client.get( + "/v1/models", + headers={ + "Authorization": f"Bearer {created.key}", + "X-Codex-Proxy-Key": "any", + }, + ) + assert misconfigured.status_code == 401 + assert misconfigured.json()["error"]["code"] == "invalid_api_key" + + @pytest.mark.asyncio async def test_codex_usage_does_not_allow_dashboard_session_without_caller_identity(async_client): setup = await async_client.post( diff --git a/tests/integration/test_migrations.py b/tests/integration/test_migrations.py index 62598c52..48db16a1 100644 --- a/tests/integration/test_migrations.py +++ b/tests/integration/test_migrations.py @@ -41,6 +41,34 @@ _STAMPED_AFTER_LEGACY_PREFIX_1 = OLD_TO_NEW_REVISION_MAP["001_normalize_account_plan_types"] +def _expected_head_revisions() -> list[str]: + return sorted(revision for revision in _HEAD_REVISION.split(",") if revision) + + +def _normalize_revision_rows(raw_revisions: list[str]) -> list[str]: + revisions: set[str] = set() + for raw in raw_revisions: + revisions.update(part for part in str(raw).split(",") if part) + return sorted(revisions) + + +async def _read_alembic_version_revisions() -> list[str]: + async with SessionLocal() as session: + revision_rows = await session.execute(text("SELECT version_num FROM alembic_version")) + raw = [str(row[0]) for row in revision_rows.fetchall()] + return _normalize_revision_rows(raw) + + +async def _replace_alembic_versions_with_single_revision(revision: str) -> None: + async with SessionLocal() as session: + await session.execute(text("DELETE FROM alembic_version")) + await session.execute( + text("INSERT INTO alembic_version (version_num) VALUES (:revision)"), + {"revision": revision}, + ) + await session.commit() + + def _is_postgresql_database_url(url: str) -> bool: return url.startswith("postgresql+") @@ -61,7 +89,7 @@ def _make_account(account_id: str, email: str, plan_type: str) -> Account: @pytest.mark.asyncio -async def test_run_startup_migrations_preserves_unknown_plan_types(db_setup): +async def test_run_startup_migrations_preserves_unknown_plan_types(raw_db_setup): async with SessionLocal() as session: repo = AccountsRepository(session) await repo.upsert(_make_account("acc_one", "one@example.com", "education")) @@ -88,7 +116,7 @@ async def test_run_startup_migrations_preserves_unknown_plan_types(db_setup): @pytest.mark.asyncio -async def test_run_startup_migrations_bootstraps_legacy_history(db_setup): +async def test_run_startup_migrations_bootstraps_legacy_history(raw_db_setup): async with SessionLocal() as session: await session.execute( text( @@ -112,14 +140,12 @@ async def test_run_startup_migrations_bootstraps_legacy_history(db_setup): assert result.bootstrap.stamped_revision == _STAMPED_AFTER_LEGACY_PREFIX_4 assert result.current_revision == _HEAD_REVISION - async with SessionLocal() as session: - revision_rows = await session.execute(text("SELECT version_num FROM alembic_version")) - revisions = [str(row[0]) for row in revision_rows.fetchall()] - assert revisions == [_HEAD_REVISION] + revisions = await _read_alembic_version_revisions() + assert revisions == _expected_head_revisions() @pytest.mark.asyncio -async def test_run_startup_migrations_skips_legacy_stamp_when_required_tables_missing(db_setup): +async def test_run_startup_migrations_skips_legacy_stamp_when_required_tables_missing(raw_db_setup): async with SessionLocal() as session: await session.execute(text("DROP TABLE dashboard_settings")) await session.execute( @@ -150,7 +176,7 @@ async def test_run_startup_migrations_skips_legacy_stamp_when_required_tables_mi @pytest.mark.asyncio -async def test_run_startup_migrations_handles_unknown_legacy_rows(db_setup): +async def test_run_startup_migrations_handles_unknown_legacy_rows(raw_db_setup): async with SessionLocal() as session: await session.execute( text( @@ -181,48 +207,37 @@ async def test_run_startup_migrations_handles_unknown_legacy_rows(db_setup): @pytest.mark.asyncio @pytest.mark.skipif(not _HAS_REVISION_REMAP, reason="requires revision remap support") -async def test_run_startup_migrations_auto_remaps_legacy_alembic_revision_ids(db_setup): +async def test_run_startup_migrations_auto_remaps_legacy_alembic_revision_ids(raw_db_setup): await run_startup_migrations(_DATABASE_URL) legacy_head = "013_add_dashboard_settings_routing_strategy" - async with SessionLocal() as session: - await session.execute(text("UPDATE alembic_version SET version_num = :legacy"), {"legacy": legacy_head}) - await session.commit() + await _replace_alembic_versions_with_single_revision(legacy_head) result = await run_startup_migrations(_DATABASE_URL) assert result.current_revision == _HEAD_REVISION - async with SessionLocal() as session: - revision_rows = await session.execute(text("SELECT version_num FROM alembic_version")) - revisions = sorted(str(row[0]) for row in revision_rows.fetchall()) - assert revisions == [_HEAD_REVISION] + revisions = await _read_alembic_version_revisions() + assert revisions == _expected_head_revisions() @pytest.mark.asyncio @pytest.mark.skipif(not _HAS_REVISION_REMAP, reason="requires revision remap support") -async def test_run_startup_migrations_auto_remaps_firewall_legacy_revision_id(db_setup): +async def test_run_startup_migrations_auto_remaps_firewall_legacy_revision_id(raw_db_setup): await run_startup_migrations(_DATABASE_URL) legacy_firewall_revision = "014_add_api_firewall_allowlist" - async with SessionLocal() as session: - await session.execute( - text("UPDATE alembic_version SET version_num = :legacy"), - {"legacy": legacy_firewall_revision}, - ) - await session.commit() + await _replace_alembic_versions_with_single_revision(legacy_firewall_revision) result = await run_startup_migrations(_DATABASE_URL) assert result.current_revision == _HEAD_REVISION - async with SessionLocal() as session: - revision_rows = await session.execute(text("SELECT version_num FROM alembic_version")) - revisions = sorted(str(row[0]) for row in revision_rows.fetchall()) - assert revisions == [_HEAD_REVISION] + revisions = await _read_alembic_version_revisions() + assert revisions == _expected_head_revisions() @pytest.mark.asyncio @pytest.mark.skipif(not _HAS_REVISION_REMAP, reason="requires revision remap support") -async def test_run_startup_migrations_handles_legacy_schema_table_and_legacy_alembic_id_together(db_setup): +async def test_run_startup_migrations_handles_legacy_schema_table_and_legacy_alembic_id_together(raw_db_setup): await run_startup_migrations(_DATABASE_URL) async with SessionLocal() as session: @@ -241,8 +256,9 @@ async def test_run_startup_migrations_handles_legacy_schema_table_and_legacy_ale text("INSERT INTO schema_migrations (name, applied_at) VALUES (:name, :applied_at)"), {"name": migration_name, "applied_at": f"2026-02-13T00:00:0{index}Z"}, ) + await session.execute(text("DELETE FROM alembic_version")) await session.execute( - text("UPDATE alembic_version SET version_num = :legacy"), + text("INSERT INTO alembic_version (version_num) VALUES (:legacy)"), {"legacy": "013_add_dashboard_settings_routing_strategy"}, ) await session.commit() @@ -257,7 +273,7 @@ async def test_run_startup_migrations_handles_legacy_schema_table_and_legacy_ale (not _is_postgresql_database_url(_DATABASE_URL)) or check_migration_policy is None, reason="PostgreSQL-only migration contract test", ) -async def test_postgresql_migration_contract_policy_and_drift_match(db_setup): +async def test_postgresql_migration_contract_policy_and_drift_match(raw_db_setup): result = await run_startup_migrations(_DATABASE_URL) assert result.current_revision == _HEAD_REVISION @@ -290,7 +306,7 @@ async def test_postgresql_upgrade_head_from_empty_database(db_setup): (not _is_postgresql_database_url(_DATABASE_URL)) or (not _HAS_REVISION_REMAP), reason="PostgreSQL-only migration remap test", ) -async def test_postgresql_startup_migration_auto_remap_legacy_head(db_setup): +async def test_postgresql_startup_migration_auto_remap_legacy_head(raw_db_setup): await run_startup_migrations(_DATABASE_URL) async with SessionLocal() as session: @@ -485,82 +501,86 @@ async def test_run_startup_migrations_drops_accounts_email_unique_with_non_casca await session.execute(text("SELECT routing_strategy FROM dashboard_settings WHERE id=1")) ).scalar_one() assert routing_strategy == "usage_weighted" - assert "openai_cache_affinity_max_age_seconds" in dashboard_columns - affinity_ttl = ( - await session.execute( - text("SELECT openai_cache_affinity_max_age_seconds FROM dashboard_settings WHERE id=1") - ) - ).scalar_one() - assert affinity_ttl == 300 - sticky_columns_rows = (await session.execute(text("PRAGMA table_info(sticky_sessions)"))).fetchall() - sticky_columns = {str(row[1]) for row in sticky_columns_rows if len(row) > 1} - assert "kind" in sticky_columns - sticky_kind = ( - await session.execute(text("SELECT kind FROM sticky_sessions WHERE key='sticky_1'")) + http_proxy_url = ( + await session.execute(text("SELECT http_proxy_url FROM dashboard_settings WHERE id=1")) ).scalar_one() - assert sticky_kind == "sticky_thread" + assert http_proxy_url is None + assert "openai_cache_affinity_max_age_seconds" in dashboard_columns + affinity_ttl = ( await session.execute( - text( - """ - INSERT INTO sticky_sessions (key, account_id, kind, created_at, updated_at) - VALUES ('sticky_1', 'acc_legacy', 'prompt_cache', '2026-01-01 00:00:00', '2026-01-01 00:00:00') - """ - ) + text("SELECT openai_cache_affinity_max_age_seconds FROM dashboard_settings WHERE id=1") ) - sticky_same_key_count = ( - await session.execute(text("SELECT COUNT(*) FROM sticky_sessions WHERE key='sticky_1'")) - ).scalar_one() - assert sticky_same_key_count == 2 - index_rows = (await session.execute(text("PRAGMA index_list(accounts)"))).fetchall() - has_email_non_unique_index = False - for row in index_rows: - if len(row) < 3: - continue - index_name = str(row[1]) - is_unique = bool(row[2]) - escaped_name = index_name.replace('"', '""') - index_info_rows = (await session.execute(text(f'PRAGMA index_info("{escaped_name}")'))).fetchall() - column_names = [str(info[2]) for info in index_info_rows if len(info) > 2] - if column_names == ["email"] and not is_unique: - has_email_non_unique_index = True - break - assert has_email_non_unique_index - usage_index_rows = (await session.execute(text("PRAGMA index_list(usage_history)"))).fetchall() - usage_index_names = {str(row[1]) for row in usage_index_rows if len(row) > 1} - assert "idx_usage_window_account_latest" in usage_index_names - request_log_index_rows = (await session.execute(text("PRAGMA index_list(request_logs)"))).fetchall() - request_log_index_names = {str(row[1]) for row in request_log_index_rows if len(row) > 1} - assert "idx_logs_requested_at_id" in request_log_index_names + ).scalar_one() + assert affinity_ttl == 300 + sticky_columns_rows = (await session.execute(text("PRAGMA table_info(sticky_sessions)"))).fetchall() + sticky_columns = {str(row[1]) for row in sticky_columns_rows if len(row) > 1} + assert "kind" in sticky_columns + sticky_kind = ( + await session.execute(text("SELECT kind FROM sticky_sessions WHERE key='sticky_1'")) + ).scalar_one() + assert sticky_kind == "sticky_thread" + await session.execute( + text( + """ + INSERT INTO sticky_sessions (key, account_id, kind, created_at, updated_at) + VALUES ('sticky_1', 'acc_legacy', 'prompt_cache', '2026-01-01 00:00:00', '2026-01-01 00:00:00') + """ + ) + ) + sticky_same_key_count = ( + await session.execute(text("SELECT COUNT(*) FROM sticky_sessions WHERE key='sticky_1'")) + ).scalar_one() + assert sticky_same_key_count == 2 + index_rows = (await session.execute(text("PRAGMA index_list(accounts)"))).fetchall() + has_email_non_unique_index = False + for row in index_rows: + if len(row) < 3: + continue + index_name = str(row[1]) + is_unique = bool(row[2]) + escaped_name = index_name.replace('"', '""') + index_info_rows = (await session.execute(text(f'PRAGMA index_info("{escaped_name}")'))).fetchall() + column_names = [str(info[2]) for info in index_info_rows if len(info) > 2] + if column_names == ["email"] and not is_unique: + has_email_non_unique_index = True + break + assert has_email_non_unique_index + usage_index_rows = (await session.execute(text("PRAGMA index_list(usage_history)"))).fetchall() + usage_index_names = {str(row[1]) for row in usage_index_rows if len(row) > 1} + assert "idx_usage_window_account_latest" in usage_index_names + request_log_index_rows = (await session.execute(text("PRAGMA index_list(request_logs)"))).fetchall() + request_log_index_names = {str(row[1]) for row in request_log_index_rows if len(row) > 1} + assert "idx_logs_requested_at_id" in request_log_index_names - await session.execute( - text( - """ - INSERT INTO accounts ( - id, chatgpt_account_id, email, plan_type, - access_token_encrypted, refresh_token_encrypted, id_token_encrypted, - last_refresh, created_at, status, deactivation_reason, reset_at - ) - VALUES ( - 'acc_legacy_2', 'chatgpt_legacy_2', 'legacy@example.com', 'team', - x'11', x'12', x'13', - '2026-01-01 00:00:00', '2026-01-01 00:00:00', 'active', NULL, NULL - ) - """ + await session.execute( + text( + """ + INSERT INTO accounts ( + id, chatgpt_account_id, email, plan_type, + access_token_encrypted, refresh_token_encrypted, id_token_encrypted, + last_refresh, created_at, status, deactivation_reason, reset_at ) + VALUES ( + 'acc_legacy_2', 'chatgpt_legacy_2', 'legacy@example.com', 'team', + x'11', x'12', x'13', + '2026-01-01 00:00:00', '2026-01-01 00:00:00', 'active', NULL, NULL + ) + """ ) - usage_count = ( - await session.execute(text("SELECT COUNT(*) FROM usage_history WHERE account_id='acc_legacy'")) - ).scalar_one() - logs_count = ( - await session.execute(text("SELECT COUNT(*) FROM request_logs WHERE account_id='acc_legacy'")) - ).scalar_one() - sticky_count = ( - await session.execute(text("SELECT COUNT(*) FROM sticky_sessions WHERE account_id='acc_legacy'")) - ).scalar_one() - await session.commit() + ) + usage_count = ( + await session.execute(text("SELECT COUNT(*) FROM usage_history WHERE account_id='acc_legacy'")) + ).scalar_one() + logs_count = ( + await session.execute(text("SELECT COUNT(*) FROM request_logs WHERE account_id='acc_legacy'")) + ).scalar_one() + sticky_count = ( + await session.execute(text("SELECT COUNT(*) FROM sticky_sessions WHERE account_id='acc_legacy'")) + ).scalar_one() + await session.commit() - assert usage_count == 1 - assert logs_count == 1 - assert sticky_count == 2 + assert usage_count == 1 + assert logs_count == 1 + assert sticky_count == 2 finally: await engine.dispose() diff --git a/tests/integration/test_settings_api.py b/tests/integration/test_settings_api.py index 630a9131..6a237ed0 100644 --- a/tests/integration/test_settings_api.py +++ b/tests/integration/test_settings_api.py @@ -16,6 +16,7 @@ async def test_settings_api_get_and_update(async_client): assert payload["routingStrategy"] == "usage_weighted" assert payload["openaiCacheAffinityMaxAgeSeconds"] == 300 assert payload["importWithoutOverwrite"] is False + assert payload["httpProxyUrl"] is None assert payload["totpRequiredOnLogin"] is False assert payload["totpConfigured"] is False assert payload["apiKeyAuthEnabled"] is False @@ -29,6 +30,7 @@ async def test_settings_api_get_and_update(async_client): "routingStrategy": "round_robin", "openaiCacheAffinityMaxAgeSeconds": 180, "importWithoutOverwrite": True, + "httpProxyUrl": "http://proxy.internal:8080", "totpRequiredOnLogin": False, "apiKeyAuthEnabled": True, }, @@ -41,6 +43,7 @@ async def test_settings_api_get_and_update(async_client): assert updated["routingStrategy"] == "round_robin" assert updated["openaiCacheAffinityMaxAgeSeconds"] == 180 assert updated["importWithoutOverwrite"] is True + assert updated["httpProxyUrl"] == "http://proxy.internal:8080" assert updated["totpRequiredOnLogin"] is False assert updated["totpConfigured"] is False assert updated["apiKeyAuthEnabled"] is True @@ -54,6 +57,7 @@ async def test_settings_api_get_and_update(async_client): assert payload["routingStrategy"] == "round_robin" assert payload["openaiCacheAffinityMaxAgeSeconds"] == 180 assert payload["importWithoutOverwrite"] is True + assert payload["httpProxyUrl"] == "http://proxy.internal:8080" assert payload["totpRequiredOnLogin"] is False assert payload["totpConfigured"] is False assert payload["apiKeyAuthEnabled"] is True diff --git a/tests/unit/test_auth_manager.py b/tests/unit/test_auth_manager.py index 78cd73f1..9fa8ec3b 100644 --- a/tests/unit/test_auth_manager.py +++ b/tests/unit/test_auth_manager.py @@ -18,6 +18,7 @@ class _DummyRepo: def __init__(self) -> None: self.tokens_payload: dict[str, object] | None = None + self.status_payload: dict[str, object] | None = None async def update_status( self, @@ -25,6 +26,11 @@ async def update_status( status: AccountStatus, deactivation_reason: str | None = None, ) -> bool: + self.status_payload = { + "account_id": account_id, + "status": status, + "deactivation_reason": deactivation_reason, + } return True async def update_tokens( @@ -85,3 +91,39 @@ async def _fake_refresh(_: str) -> TokenRefreshResult: assert updated.plan_type == "pro" assert repo.tokens_payload is not None assert repo.tokens_payload["plan_type"] == "pro" + + +@pytest.mark.asyncio +async def test_refresh_account_deactivates_on_permanent_refresh_error(monkeypatch): + async def _fake_refresh(_: str) -> TokenRefreshResult: + raise auth_manager_module.RefreshError( + code="invalid_grant", + message="Your refresh token has already been used to generate a new access token. Please try signing in again.", + is_permanent=True, + ) + + monkeypatch.setattr(auth_manager_module, "refresh_access_token", _fake_refresh) + + encryptor = TokenEncryptor() + account = Account( + id="acc_perm", + email="user@example.com", + plan_type="pro", + access_token_encrypted=encryptor.encrypt("access-old"), + refresh_token_encrypted=encryptor.encrypt("refresh-old"), + id_token_encrypted=encryptor.encrypt("id-old"), + last_refresh=utcnow(), + status=AccountStatus.ACTIVE, + deactivation_reason=None, + ) + repo = _DummyRepo() + manager = AuthManager(repo) + + with pytest.raises(auth_manager_module.RefreshError): + await manager.refresh_account(account) + + assert account.status == AccountStatus.DEACTIVATED + assert repo.status_payload is not None + assert repo.status_payload["account_id"] == "acc_perm" + assert repo.status_payload["status"] == AccountStatus.DEACTIVATED + assert account.deactivation_reason == "Refresh token was reused - re-login required" diff --git a/tests/unit/test_auth_refresh.py b/tests/unit/test_auth_refresh.py index 294911bc..4e62b639 100644 --- a/tests/unit/test_auth_refresh.py +++ b/tests/unit/test_auth_refresh.py @@ -4,7 +4,7 @@ import pytest -from app.core.auth.refresh import classify_refresh_error, should_refresh +from app.core.auth.refresh import _is_permanent_refresh_failure, classify_refresh_error, should_refresh from app.core.utils.time import utcnow pytestmark = pytest.mark.unit @@ -26,3 +26,19 @@ def test_classify_refresh_error_permanent(): def test_classify_refresh_error_temporary(): assert classify_refresh_error("temporary_error") is False + + +def test_refresh_failure_401_invalid_grant_is_permanent(): + assert _is_permanent_refresh_failure( + "invalid_grant", + "Your refresh token has already been used to generate a new access token. Please try signing in again.", + 401, + ) is True + + +def test_refresh_failure_non_401_invalid_grant_is_not_forced_permanent(): + assert _is_permanent_refresh_failure( + "invalid_grant", + "temporary upstream failure", + 500, + ) is False diff --git a/tests/unit/test_auth_refresh_http.py b/tests/unit/test_auth_refresh_http.py new file mode 100644 index 00000000..e6efce40 --- /dev/null +++ b/tests/unit/test_auth_refresh_http.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +import pytest + +from app.core.auth import refresh as refresh_module +from app.core.auth.refresh import ( + DEFAULT_REFRESH_TOKEN_URL, + REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR, + refresh_access_token, + refresh_token_endpoint, +) + +pytestmark = pytest.mark.unit + + +class _FakeResponse: + def __init__(self, *, status: int, payload: dict[str, object]) -> None: + self.status = status + self._payload = payload + + async def __aenter__(self) -> _FakeResponse: + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + return None + + async def json(self, content_type=None): + return self._payload + + async def text(self) -> str: + return str(self._payload) + + +class _FakeSession: + def __init__(self, response: _FakeResponse) -> None: + self._response = response + self.captured_url: str | None = None + self.captured_json: dict[str, object] | None = None + self.captured_proxy: str | None = None + + def post(self, url: str, *, json, headers, timeout, proxy=None): # noqa: ANN001 + self.captured_url = url + self.captured_json = json + self.captured_proxy = proxy + return self._response + + +class _FakeSettingsRow: + http_proxy_url = "http://dashboard.proxy:3128" + + +class _FakeSettingsCache: + async def get(self) -> _FakeSettingsRow: + return _FakeSettingsRow() + + +def test_refresh_token_endpoint_uses_default_url(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv(REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR, raising=False) + assert refresh_token_endpoint() == DEFAULT_REFRESH_TOKEN_URL + + +def test_refresh_token_endpoint_uses_override(monkeypatch: pytest.MonkeyPatch) -> None: + override_url = "https://example.test/custom/token" + monkeypatch.setenv(REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR, f" {override_url} ") + assert refresh_token_endpoint() == override_url + + +@pytest.mark.asyncio +async def test_refresh_access_token_posts_refresh_payload(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv(REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR, raising=False) + refresh_module.get_settings.cache_clear() + fake = _FakeSession( + _FakeResponse( + status=200, + payload={ + "access_token": "new-access", + "refresh_token": "new-refresh", + "id_token": "new-id-token", + }, + ) + ) + + result = await refresh_access_token("old-refresh-token", session=fake) + + assert result.access_token == "new-access" + assert result.refresh_token == "new-refresh" + assert result.id_token == "new-id-token" + assert fake.captured_url == DEFAULT_REFRESH_TOKEN_URL + assert fake.captured_json is not None + assert fake.captured_json["client_id"] == "app_EMoamEEZ73f0CkXaXp7hrann" + assert fake.captured_json["grant_type"] == "refresh_token" + assert fake.captured_json["refresh_token"] == "old-refresh-token" + assert fake.captured_json["scope"] == "openid profile email" + assert fake.captured_proxy is None + + +@pytest.mark.asyncio +async def test_refresh_access_token_uses_env_http_proxy(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("CODEX_LB_HTTP_PROXY_URL", "http://env.proxy:8080") + refresh_module.get_settings.cache_clear() + fake = _FakeSession( + _FakeResponse( + status=200, + payload={ + "access_token": "new-access", + "refresh_token": "new-refresh", + "id_token": "new-id-token", + }, + ) + ) + + await refresh_access_token("old-refresh-token", session=fake) + + assert fake.captured_proxy == "http://env.proxy:8080" + + +@pytest.mark.asyncio +async def test_refresh_access_token_uses_dashboard_http_proxy_when_env_absent(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("CODEX_LB_HTTP_PROXY_URL", raising=False) + refresh_module.get_settings.cache_clear() + monkeypatch.setattr("app.core.clients.http.get_settings_cache", lambda: _FakeSettingsCache()) + fake = _FakeSession( + _FakeResponse( + status=200, + payload={ + "access_token": "new-access", + "refresh_token": "new-refresh", + "id_token": "new-id-token", + }, + ) + ) + + await refresh_access_token("old-refresh-token", session=fake) + + assert fake.captured_proxy == "http://dashboard.proxy:3128" diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index c281f2a4..27037690 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -21,16 +21,19 @@ def fake_run(*args, **kwargs): monkeypatch.setattr(sys, "argv", ["codex-lb"]) monkeypatch.setattr(cli.uvicorn, "run", fake_run) + monkeypatch.setenv("CODEX_LB_LOG_LEVEL", "debug") cli.main() + args = captured["args"] + assert args[0] == "app.main:app" kwargs = captured["kwargs"] assert isinstance(kwargs, dict) log_config = kwargs["log_config"] assert isinstance(log_config, dict) formatters = log_config["formatters"] - assert formatters["default"]["fmt"].startswith("%(asctime)s ") - assert formatters["access"]["fmt"].startswith("%(asctime)s ") + assert formatters["standard"]["format"].startswith("%(asctime)s ") + assert log_config["loggers"]["uvicorn.access"]["level"] == "WARNING" def test_utc_default_formatter_formats_without_converter_binding_error(): diff --git a/tests/unit/test_logging_config.py b/tests/unit/test_logging_config.py new file mode 100644 index 00000000..795eed42 --- /dev/null +++ b/tests/unit/test_logging_config.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import logging + +import pytest + +from app.core.logging import build_log_config, configure_logging, resolve_log_level + +pytestmark = pytest.mark.unit + + +def test_resolve_log_level_falls_back_to_info_for_invalid_values() -> None: + assert resolve_log_level("debug") == "DEBUG" + assert resolve_log_level(" noisy ") == "INFO" + assert resolve_log_level(None) == "INFO" + + +def test_build_log_config_disables_uvicorn_access_logs_and_uses_stdout_handler() -> None: + config = build_log_config("debug") + + assert config["root"]["level"] == "DEBUG" + assert config["handlers"]["default"]["stream"] == "ext://sys.stdout" + assert config["loggers"]["uvicorn.access"] == { + "handlers": [], + "level": "WARNING", + "propagate": False, + } + + +def test_configure_logging_applies_root_level(monkeypatch) -> None: + monkeypatch.setenv("CODEX_LB_LOG_LEVEL", "debug") + + resolved = configure_logging() + + assert resolved == "DEBUG" + assert logging.getLogger().getEffectiveLevel() == logging.DEBUG diff --git a/tests/unit/test_proxy_utils.py b/tests/unit/test_proxy_utils.py index bbdab539..27a57766 100644 --- a/tests/unit/test_proxy_utils.py +++ b/tests/unit/test_proxy_utils.py @@ -902,7 +902,7 @@ async def fake_iter(resp, idle_timeout_seconds, max_event_bytes): assert events == ['data: {"type":"response.completed","response":{"id":"resp_1"}}\n\n'] timeout = session.calls[0]["timeout"] assert isinstance(timeout, proxy_module.aiohttp.ClientTimeout) - assert timeout.total == pytest.approx(4.5, abs=0.01) + assert timeout.total == pytest.approx(4.5, abs=0.05) assert timeout.sock_connect == pytest.approx(2.5) assert seen["idle_timeout_seconds"] == pytest.approx(3.5) diff --git a/tests/unit/test_settings_firewall.py b/tests/unit/test_settings_firewall.py index 229bfb42..36fef85a 100644 --- a/tests/unit/test_settings_firewall.py +++ b/tests/unit/test_settings_firewall.py @@ -18,3 +18,17 @@ def test_settings_rejects_invalid_firewall_trusted_proxy_cidr(monkeypatch): monkeypatch.setenv("CODEX_LB_FIREWALL_TRUSTED_PROXY_CIDRS", "not-a-cidr") with pytest.raises(ValidationError): Settings() + + +def test_settings_parses_optional_proxy_key(monkeypatch): + monkeypatch.setenv("CODEX_LB_PROXY_KEY_AUTH_ENABLED", "true") + monkeypatch.setenv("CODEX_LB_PROXY_KEY", " shared-secret ") + settings = Settings() + assert settings.proxy_key_auth_enabled is True + assert settings.proxy_key == "shared-secret" + + +def test_settings_normalizes_empty_optional_proxy_key(monkeypatch): + monkeypatch.setenv("CODEX_LB_PROXY_KEY", " ") + settings = Settings() + assert settings.proxy_key is None diff --git a/tests/unit/test_usage_client.py b/tests/unit/test_usage_client.py index 50fccafb..52913ea3 100644 --- a/tests/unit/test_usage_client.py +++ b/tests/unit/test_usage_client.py @@ -29,6 +29,7 @@ class UsageClientState: calls: int = 0 auth: str | None = None account: str | None = None + proxy: str | None = None class StubRequestContext: @@ -77,7 +78,9 @@ def request( headers: dict[str, str] | None = None, timeout: object | None = None, retry_options: object | None = None, + proxy: str | None = None, ) -> StubRequestContext: + self._state.proxy = proxy return StubRequestContext(self._responses, self._state, headers or {}, retry_options) @@ -146,3 +149,25 @@ async def test_fetch_usage_raises_after_retries(failing_usage_server): exc = excinfo.value assert isinstance(exc, UsageFetchError) assert exc.status_code == 503 + + +@pytest.mark.asyncio +async def test_fetch_usage_sanitizes_html_error_body() -> None: + state = UsageClientState() + client = StubRetryClient( + [StubResponse(403, None, "forbidden")], + state, + ) + + with pytest.raises(UsageFetchError) as excinfo: + await fetch_usage( + access_token="access-token", + account_id=None, + base_url="http://usage.test/backend-api", + max_retries=0, + timeout_seconds=1.0, + client=client, + ) + + assert excinfo.value.status_code == 403 + assert excinfo.value.message == "Upstream returned an HTML error response" diff --git a/tests/unit/test_usage_updater.py b/tests/unit/test_usage_updater.py index 3e2ce0fc..1a41a41a 100644 --- a/tests/unit/test_usage_updater.py +++ b/tests/unit/test_usage_updater.py @@ -347,7 +347,7 @@ async def update_tokens(self, *args: Any, **kwargs: Any) -> bool: @pytest.mark.asyncio -async def test_usage_updater_deactivates_on_account_invalid_4xx(monkeypatch) -> None: +async def test_usage_updater_does_not_deactivate_on_account_invalid_4xx(monkeypatch) -> None: monkeypatch.setenv("CODEX_LB_USAGE_REFRESH_ENABLED", "true") from app.core.clients.usage import UsageFetchError from app.core.config.settings import get_settings @@ -368,12 +368,7 @@ async def stub_fetch_usage_402(**_: Any) -> UsagePayload: await updater.refresh_accounts([acc], latest_usage={}) - assert len(accounts_repo.status_updates) == 1 - update = accounts_repo.status_updates[0] - assert update["account_id"] == "acc_402" - assert update["status"] == AccountStatus.DEACTIVATED - assert "402" in update["deactivation_reason"] - assert "Payment Required" in update["deactivation_reason"] + assert len(accounts_repo.status_updates) == 0 @pytest.mark.asyncio @@ -652,6 +647,43 @@ async def stub_ensure_fresh(account: Account, *, force: bool = False) -> Account assert len(usage_repo.entries) == 0 +@pytest.mark.asyncio +async def test_usage_updater_deactivates_account_when_401_refresh_fails_permanently(monkeypatch) -> None: + monkeypatch.setenv("CODEX_LB_USAGE_REFRESH_ENABLED", "true") + from app.core.clients.usage import UsageFetchError + from app.core.config.settings import get_settings + + get_settings.cache_clear() + + async def stub_fetch_usage_401(**_: Any) -> UsagePayload: + raise UsageFetchError(401, "Provided authentication token is expired. Please try signing in again.") + + monkeypatch.setattr("app.modules.usage.updater.fetch_usage", stub_fetch_usage_401) + + usage_repo = StubUsageRepository(return_rows=True) + accounts_repo = StubAccountsRepository() + updater = UsageUpdater(usage_repo, accounts_repo=accounts_repo) + assert updater._auth_manager is not None + + async def stub_ensure_fresh(account: Account, *, force: bool = False) -> Account: + account.status = AccountStatus.DEACTIVATED + account.deactivation_reason = "Refresh token was reused - re-login required" + raise RefreshError( + code="refresh_token_reused", + message="Your refresh token has already been used to generate a new access token. Please try signing in again.", + is_permanent=True, + ) + + monkeypatch.setattr(updater._auth_manager, "ensure_fresh", stub_ensure_fresh) + + acc = _make_account("acc_401_perm", "workspace_401_perm", email="auth-perm@example.com") + refreshed = await updater.refresh_accounts([acc], latest_usage={}) + + assert refreshed is False + assert acc.status == AccountStatus.DEACTIVATED + assert len(usage_repo.entries) == 0 + + @pytest.mark.parametrize( ("primary_used", "secondary_used"), [ diff --git a/uv.lock b/uv.lock index 8ec940dc..90783304 100644 --- a/uv.lock +++ b/uv.lock @@ -368,7 +368,7 @@ wheels = [ [[package]] name = "codex-lb" -version = "1.4.1" +version = "1.5.2" source = { editable = "." } dependencies = [ { name = "aiohttp" },