diff --git a/.env.example b/.env.example index 4379e519..67250221 100644 --- a/.env.example +++ b/.env.example @@ -9,6 +9,40 @@ CODEX_LB_DATABASE_SQLITE_PRE_MIGRATE_BACKUP_MAX_FILES=5 # Upstream ChatGPT base URL (no /codex suffix) CODEX_LB_UPSTREAM_BASE_URL=https://chatgpt.com/backend-api +# Anthropic SDK runtime +# Optional explicit Claude CLI path when auto-discovery via PATH is not enough +# CODEX_LB_ANTHROPIC_SDK_CLI_PATH=/usr/local/bin/claude +# Optional default session ID for SDK conversation continuity +# CODEX_LB_ANTHROPIC_SDK_DEFAULT_SESSION_ID=codex-lb-default +CODEX_LB_ANTHROPIC_SDK_POOL_ENABLED=false + +# Anthropic direct API runtime (/claude/v1/messages) +CODEX_LB_ANTHROPIC_API_BASE_URL=https://api.anthropic.com +CODEX_LB_ANTHROPIC_API_VERSION=2023-06-01 +# CODEX_LB_ANTHROPIC_API_BETA= +CODEX_LB_ANTHROPIC_API_TIMEOUT_SECONDS=300 +CODEX_LB_ANTHROPIC_API_DETECT_CLI_HEADERS=true +CODEX_LB_ANTHROPIC_API_SYSTEM_PROMPT_INJECTION_MODE=minimal +CODEX_LB_ANTHROPIC_OAUTH_TOKEN_URL=https://console.anthropic.com/v1/oauth/token +CODEX_LB_ANTHROPIC_OAUTH_CLIENT_ID=9d1c250a-e61b-44d9-88ed-5944d1962f5e + +# Anthropic usage (5h/7d) +CODEX_LB_ANTHROPIC_USAGE_BASE_URL=https://api.anthropic.com +CODEX_LB_ANTHROPIC_USAGE_BETA=oauth-2025-04-20 +CODEX_LB_ANTHROPIC_USAGE_REFRESH_ENABLED=true + +# Anthropic auth discovery (Linux-only POC) +CODEX_LB_ANTHROPIC_CREDENTIALS_DISCOVERY_ENABLED=true +# CODEX_LB_ANTHROPIC_CREDENTIALS_FILE=~/.claude/.credentials.json +# CODEX_LB_ANTHROPIC_CREDENTIALS_HELPER_COMMAND= +# Explicit override (if set, used before discovery) +# CODEX_LB_ANTHROPIC_USAGE_BEARER_TOKEN=sk-ant-oat01-... 
+CODEX_LB_ANTHROPIC_AUTO_DISCOVER_ORG=false +CODEX_LB_ANTHROPIC_CREDENTIALS_CACHE_SECONDS=60 +CODEX_LB_ANTHROPIC_DEFAULT_ACCOUNT_ID=anthropic_default +CODEX_LB_ANTHROPIC_DEFAULT_ACCOUNT_EMAIL=anthropic@local +CODEX_LB_ANTHROPIC_DEFAULT_PLAN_TYPE=pro + # Timeouts (seconds) CODEX_LB_UPSTREAM_CONNECT_TIMEOUT_SECONDS=30 CODEX_LB_STREAM_IDLE_TIMEOUT_SECONDS=300 diff --git a/README.md b/README.md index 564c7423..d7d3fca7 100644 --- a/README.md +++ b/README.md @@ -222,6 +222,57 @@ print(response.choices[0].message.content) +## Anthropic Messages Mode (POC) + +`codex-lb` serves two Anthropic-compatible routes: + +- `POST /claude/v1/messages` (direct OAuth-backed API proxy to `api.anthropic.com/v1/messages`) +- `POST /claude-sdk/v1/messages` (local Claude SDK runtime via `claude-agent-sdk`) + +Start the server (OpenAI routes and both Anthropic routes are enabled): + +```bash +uv run fastapi run app/main.py --host 0.0.0.0 --port 2455 +``` + +Prerequisites: + +```bash +claude /login +uv sync +``` + +Usage windows source (Linux-only auto-discovery + optional overrides): + +- Set `CODEX_LB_ANTHROPIC_USAGE_REFRESH_ENABLED=true` to enable 5h/7d usage ingestion. +- Auto-discovery from local Claude credentials (`claude login`) is enabled by default for 5h/7d usage ingestion. +- Usage polling calls Anthropic OAuth usage API (`/api/oauth/usage`) with the configured OAuth beta header. +- You can override usage auth explicitly with: + +```bash +export CODEX_LB_ANTHROPIC_USAGE_BEARER_TOKEN="sk-ant-oat01-..." +``` + +Example request (API route): + +```bash +curl -sS http://127.0.0.1:2455/claude/v1/messages \ + -H 'content-type: application/json' \ + -H 'anthropic-version: 2023-06-01' \ + -d '{ + "model": "claude-sonnet-4-20250514", + "max_tokens": 1024, + "messages": [{"role":"user","content":"Hello"}] + }' +``` + +Notes: + +- `/claude/v1/messages` sends requests directly to `api.anthropic.com/v1/messages` using discovered or configured OAuth credentials. 
+- `/claude-sdk/v1/messages` runs generation through the local Claude SDK runtime. +- API compatibility for `/claude/v1/messages` includes request normalization and optional CLI header/system-prompt parity helpers. +- OpenAI compatibility routes (`/v1/responses`, `/v1/chat/completions`) stay available in the same server instance. + ## API Key Authentication API key auth is **disabled by default** — the proxy is open to any client. Enable it in **Settings → API Key Auth** on the dashboard. diff --git a/app/core/auth/anthropic_credentials.py b/app/core/auth/anthropic_credentials.py new file mode 100644 index 00000000..bf79ee51 --- /dev/null +++ b/app/core/auth/anthropic_credentials.py @@ -0,0 +1,581 @@ +from __future__ import annotations + +import asyncio +import json +import logging +import subprocess +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import aiohttp + +from app.core.clients.http import get_http_client +from app.core.config.settings import Settings, get_settings +from app.db.models import Account + +logger = logging.getLogger(__name__) + +_TOKEN_PREFIX = "sk-ant-oat" +_ORG_KEY_CANDIDATES = { + "organization_id", + "organizationid", + "org_id", + "orgid", + "active_org_id", + "activeorgid", + "active_organization_id", + "activeorganizationid", +} + +_cache_lock = asyncio.Lock() +_cached_credentials: AnthropicCredentials | None = None +_cached_at_monotonic = 0.0 + + +@dataclass(frozen=True, slots=True) +class AnthropicCredentials: + bearer_token: str + org_id: str | None + source: str + refresh_token: str | None = None + expires_at_ms: int | None = None + source_path: Path | None = None + + +@dataclass(frozen=True, slots=True) +class AnthropicAuthFile: + access_token: str + refresh_token: str | None + org_id: str | None + expires_at_ms: int | None + email: str | None + + +async def resolve_anthropic_credentials(*, force_refresh: bool = False) -> AnthropicCredentials | None: + settings = 
get_settings() + token_override = _normalize_secret(settings.anthropic_usage_bearer_token) + org_override = _normalize_identifier(settings.anthropic_org_id) + if token_override: + org_id = org_override + if org_id is None and settings.anthropic_auto_discover_org: + org_id = await _discover_org_id(token_override, settings) + return AnthropicCredentials( + bearer_token=token_override, + org_id=org_id, + source="env", + ) + + if not settings.anthropic_credentials_discovery_enabled: + return None + if not _is_linux(): + return None + + ttl_seconds = settings.anthropic_credentials_cache_seconds + if ttl_seconds > 0 and not force_refresh: + now = time.monotonic() + if _cached_credentials is not None and now - _cached_at_monotonic < ttl_seconds: + return _cached_credentials + + async with _cache_lock: + if ttl_seconds > 0 and not force_refresh: + now = time.monotonic() + if _cached_credentials is not None and now - _cached_at_monotonic < ttl_seconds: + return _cached_credentials + + resolved = await _resolve_uncached(settings, org_override=org_override) + _set_cache(resolved) + return resolved + + +def clear_anthropic_credentials_cache() -> None: + global _cached_credentials, _cached_at_monotonic + _cached_credentials = None + _cached_at_monotonic = 0.0 + + +def credentials_from_account(account: Account) -> AnthropicCredentials | None: + try: + from app.core.crypto import TokenEncryptor + + encryptor = TokenEncryptor() + encrypted_access = bytes(account.access_token_encrypted) + encrypted_refresh = bytes(account.refresh_token_encrypted) + access_token = _normalize_secret(encryptor.decrypt(encrypted_access)) + if access_token is None or not access_token.startswith(_TOKEN_PREFIX): + return None + refresh_token = _normalize_secret(encryptor.decrypt(encrypted_refresh)) + except Exception: + return None + + return AnthropicCredentials( + bearer_token=access_token, + org_id=None, + source=f"db-account:{account.id}", + refresh_token=refresh_token, + ) + + +def 
parse_anthropic_auth_json(raw: bytes) -> AnthropicAuthFile: + data = json.loads(raw) + structured = _extract_structured_credentials(data) + if structured is None: + raise ValueError("Unable to extract Anthropic OAuth credentials") + + email = _extract_email(data) + return AnthropicAuthFile( + access_token=structured.bearer_token, + refresh_token=structured.refresh_token, + org_id=structured.org_id, + expires_at_ms=structured.expires_at_ms, + email=email, + ) + + +def _set_cache(value: AnthropicCredentials | None) -> None: + global _cached_credentials, _cached_at_monotonic + _cached_credentials = value + _cached_at_monotonic = time.monotonic() + + +def _parse_int(value: object) -> int | None: + if isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, float): + return int(value) + if isinstance(value, str): + stripped = value.strip() + if not stripped: + return None + try: + return int(stripped) + except ValueError: + return None + return None + + +def _parse_expires_at_ms(payload: dict[str, Any]) -> int | None: + expires_at = _parse_int(payload.get("expires_at") or payload.get("expiresAt")) + if expires_at is not None: + if expires_at < 10_000_000_000: + return expires_at * 1000 + return expires_at + + expires_in = _parse_int(payload.get("expires_in") or payload.get("expiresIn")) + if expires_in is None: + return None + now_seconds = int(time.time()) + return (now_seconds + expires_in) * 1000 + + +async def _safe_json_response(resp: aiohttp.ClientResponse) -> dict[str, Any]: + try: + data = await resp.json(content_type=None) + except Exception: + return {} + return data if isinstance(data, dict) else {} + + +async def _resolve_uncached(settings: Settings, *, org_override: str | None) -> AnthropicCredentials | None: + discovered = _discover_from_files(settings) + if discovered is None: + discovered = _discover_from_helper_command(settings) + if discovered is None: + return None + + token = discovered.bearer_token + 
org_id = org_override or discovered.org_id + if org_id is None and settings.anthropic_auto_discover_org: + org_id = await _discover_org_id(token, settings) + return AnthropicCredentials( + bearer_token=token, + org_id=org_id, + source=discovered.source, + refresh_token=discovered.refresh_token, + expires_at_ms=discovered.expires_at_ms, + source_path=discovered.source_path, + ) + + +async def refresh_anthropic_access_token( + credentials: AnthropicCredentials, +) -> AnthropicCredentials | None: + refresh_token = _normalize_secret(credentials.refresh_token) + if not refresh_token: + return None + + settings = get_settings() + timeout = aiohttp.ClientTimeout(total=settings.oauth_timeout_seconds) + payload = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": settings.anthropic_oauth_client_id, + } + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + } + + try: + async with get_http_client().session.post( + settings.anthropic_oauth_token_url, + json=payload, + headers=headers, + timeout=timeout, + ) as response: + data = await _safe_json_response(response) + if response.status >= 400: + return None + except Exception: + return None + + access_token = _normalize_secret(_read_string(data, "access_token")) + if access_token is None: + return None + + new_refresh_token = _normalize_secret(_read_string(data, "refresh_token")) or refresh_token + expires_at_ms = _parse_expires_at_ms(data) + refreshed = AnthropicCredentials( + bearer_token=access_token, + org_id=credentials.org_id, + source=f"{credentials.source}:refreshed", + refresh_token=new_refresh_token, + expires_at_ms=expires_at_ms, + source_path=credentials.source_path, + ) + _set_cache(refreshed) + return refreshed + + +@dataclass(frozen=True, slots=True) +class _RawDiscoveredCredentials: + bearer_token: str + org_id: str | None + source: str + refresh_token: str | None = None + expires_at_ms: int | None = None + source_path: Path | None = None + + +def 
_discover_from_files(settings: Settings) -> _RawDiscoveredCredentials | None: + candidates = _credential_candidates(settings) + for path in candidates: + if not path.is_file(): + continue + payload = _load_json_file(path) + if payload is None: + continue + structured = _extract_structured_credentials(payload) + token = structured.bearer_token if structured is not None else _extract_token(payload) + if token is None: + continue + org_id = structured.org_id if structured is not None else _extract_org_id(payload) + return _RawDiscoveredCredentials( + bearer_token=token, + org_id=org_id, + source=f"file:{path}", + refresh_token=structured.refresh_token if structured is not None else None, + expires_at_ms=structured.expires_at_ms if structured is not None else None, + source_path=path, + ) + return None + + +def _discover_from_helper_command(settings: Settings) -> _RawDiscoveredCredentials | None: + command = (settings.anthropic_credentials_helper_command or "").strip() + if not command: + return None + try: + completed = subprocess.run( + command, + shell=True, + check=False, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + timeout=max(1.0, settings.upstream_connect_timeout_seconds), + ) + except Exception: + logger.warning("anthropic_credentials_helper_failed", exc_info=True) + return None + + if completed.returncode != 0: + logger.warning( + "anthropic_credentials_helper_nonzero return_code=%s stderr=%s", + completed.returncode, + completed.stderr.strip(), + ) + return None + + parsed = _parse_helper_output(completed.stdout) + if parsed is None: + return None + return _RawDiscoveredCredentials( + bearer_token=parsed.bearer_token, + org_id=parsed.org_id, + source="helper-command", + refresh_token=parsed.refresh_token, + expires_at_ms=parsed.expires_at_ms, + ) + + +@dataclass(frozen=True, slots=True) +class _HelperCredentials: + bearer_token: str + org_id: str | None + refresh_token: str | None = None + expires_at_ms: int | None = None + + 
+@dataclass(frozen=True, slots=True) +class _StructuredCredentials: + bearer_token: str + org_id: str | None + refresh_token: str | None + expires_at_ms: int | None + + +def _parse_helper_output(stdout: str) -> _HelperCredentials | None: + raw = stdout.strip() + if not raw: + return None + try: + payload = json.loads(raw) + except json.JSONDecodeError: + token = _normalize_secret(raw.splitlines()[0] if raw.splitlines() else raw) + if token is None: + return None + return _HelperCredentials(bearer_token=token, org_id=None) + + if isinstance(payload, dict): + structured = _extract_structured_credentials(payload) + if structured is not None: + return _HelperCredentials( + bearer_token=structured.bearer_token, + org_id=structured.org_id, + refresh_token=structured.refresh_token, + expires_at_ms=structured.expires_at_ms, + ) + + token = _normalize_secret(_read_string(payload, "token") or _read_string(payload, "bearer_token")) + if token is None: + token = _extract_token(payload) + if token is None: + return None + org_id = _normalize_identifier( + _read_string(payload, "org_id") + or _read_string(payload, "organization_id") + or _read_string(payload, "orgId") + ) + if org_id is None: + org_id = _extract_org_id(payload) + return _HelperCredentials(bearer_token=token, org_id=org_id) + return None + + +def _credential_candidates(settings: Settings) -> list[Path]: + candidates: list[Path] = [] + if settings.anthropic_credentials_file is not None: + candidates.append(settings.anthropic_credentials_file) + + home = Path.home() + candidates.extend( + [ + home / ".claude/.credentials.json", + home / ".claude/credentials.json", + home / ".config/claude/.credentials.json", + home / ".config/claude/credentials.json", + ] + ) + return candidates + + +def _load_json_file(path: Path) -> Any | None: + try: + raw = path.read_text(encoding="utf-8") + except OSError: + return None + try: + return json.loads(raw) + except json.JSONDecodeError: + 
logger.warning("anthropic_credentials_invalid_json path=%s", path) + return None + + +def _extract_token(payload: Any) -> str | None: + for value in _walk_strings(payload): + token = _normalize_secret(value) + if token is not None and token.startswith(_TOKEN_PREFIX): + return token + return None + + +def _extract_org_id(payload: Any) -> str | None: + if isinstance(payload, dict): + for key, value in payload.items(): + normalized_key = key.strip().lower().replace("-", "_") + if normalized_key in _ORG_KEY_CANDIDATES: + org_value = _normalize_identifier(value if isinstance(value, str) else None) + if org_value: + return org_value + nested = _extract_org_id(value) + if nested: + return nested + return None + if isinstance(payload, list): + for item in payload: + nested = _extract_org_id(item) + if nested: + return nested + return None + + +def _extract_structured_credentials(payload: Any) -> _StructuredCredentials | None: + if not isinstance(payload, dict): + return None + + claude_oauth = payload.get("claudeAiOauth") + if isinstance(claude_oauth, dict): + access_token = _normalize_secret(_read_string(claude_oauth, "accessToken")) + if access_token is not None: + refresh_token = _normalize_secret(_read_string(claude_oauth, "refreshToken")) + expires_at_ms = _parse_int(claude_oauth.get("expiresAt")) + org_id = _extract_org_id(payload) + return _StructuredCredentials( + bearer_token=access_token, + org_id=org_id, + refresh_token=refresh_token, + expires_at_ms=expires_at_ms, + ) + + session = payload.get("session") + if isinstance(session, dict): + oauth = session.get("oauth") + if isinstance(oauth, dict): + access_token = _normalize_secret( + _read_string(oauth, "token") + or _read_string(oauth, "access_token") + or _read_string(oauth, "accessToken") + ) + if access_token is not None: + refresh_token = _normalize_secret( + _read_string(oauth, "refresh_token") or _read_string(oauth, "refreshToken") + ) + expires_at_ms = _parse_expires_at_ms(oauth) + org_id = 
_extract_org_id(payload) + return _StructuredCredentials( + bearer_token=access_token, + org_id=org_id, + refresh_token=refresh_token, + expires_at_ms=expires_at_ms, + ) + + return None + + +def _extract_email(payload: Any) -> str | None: + if isinstance(payload, dict): + for key, value in payload.items(): + normalized_key = key.strip().lower().replace("-", "_") + if normalized_key == "email" and isinstance(value, str): + email = value.strip() + if "@" in email: + return email + nested = _extract_email(value) + if nested is not None: + return nested + return None + if isinstance(payload, list): + for item in payload: + nested = _extract_email(item) + if nested is not None: + return nested + return None + + +def _walk_strings(value: Any): + if isinstance(value, str): + yield value + return + if isinstance(value, dict): + for entry in value.values(): + yield from _walk_strings(entry) + return + if isinstance(value, list): + for entry in value: + yield from _walk_strings(entry) + + +def _read_string(payload: dict[str, Any], key: str) -> str | None: + value = payload.get(key) + if isinstance(value, str): + return value + return None + + +def _normalize_secret(value: str | None) -> str | None: + if value is None: + return None + stripped = value.strip() + if not stripped: + return None + return stripped + + +def _normalize_identifier(value: str | None) -> str | None: + if value is None: + return None + stripped = value.strip() + if not stripped: + return None + return stripped + + +def _is_linux() -> bool: + return sys.platform.startswith("linux") + + +async def _discover_org_id(token: str, settings: Settings) -> str | None: + url = f"{settings.anthropic_usage_base_url.rstrip('/')}/api/organizations" + timeout = aiohttp.ClientTimeout(total=settings.usage_fetch_timeout_seconds) + headers = { + "Authorization": f"Bearer {token}", + "Accept": "application/json", + } + try: + async with get_http_client().session.get(url, headers=headers, timeout=timeout) as resp: + if 
resp.status >= 400: + return None + payload = await resp.json(content_type=None) + except Exception: + return None + + return _extract_org_id_from_orgs_payload(payload) + + +def _extract_org_id_from_orgs_payload(payload: Any) -> str | None: + objects: list[Any] = [] + if isinstance(payload, list): + objects.extend(payload) + elif isinstance(payload, dict): + objects.append(payload) + for key in ("organizations", "data", "items"): + value = payload.get(key) + if isinstance(value, list): + objects.extend(value) + + for obj in objects: + if not isinstance(obj, dict): + continue + for key in ("id", "uuid", "organization_id", "organizationId"): + value = obj.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + nested = _extract_org_id(obj) + if nested: + return nested + return None diff --git a/app/core/auth/dependencies.py b/app/core/auth/dependencies.py index 7ed81ad8..13b92112 100644 --- a/app/core/auth/dependencies.py +++ b/app/core/auth/dependencies.py @@ -30,6 +30,10 @@ def set_dashboard_error_format(request: Request) -> None: request.state.error_format = "dashboard" +def set_anthropic_error_format(request: Request) -> None: + request.state.error_format = "anthropic" + + # --- Proxy API key auth --- diff --git a/app/core/clients/anthropic_api_proxy.py b/app/core/clients/anthropic_api_proxy.py new file mode 100644 index 00000000..0a20d2c4 --- /dev/null +++ b/app/core/clients/anthropic_api_proxy.py @@ -0,0 +1,717 @@ +from __future__ import annotations + +import asyncio +import json +import logging +import os +import socket +import tempfile +import time +from collections import deque +from collections.abc import AsyncIterator, Mapping +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import aiohttp +from aiohttp import web + +from app.core.auth.anthropic_credentials import ( + AnthropicCredentials, + refresh_anthropic_access_token, + 
resolve_anthropic_credentials, +) +from app.core.clients.http import get_http_client +from app.core.config.settings import get_settings +from app.core.types import JsonValue + +from .anthropic_proxy import AnthropicProxyError, anthropic_error_payload + +_IGNORE_INBOUND_HEADERS = { + "authorization", + "host", + "content-length", + "x-api-key", + "transfer-encoding", + "connection", +} + +_CACHE_TTL_SECONDS = 900 +_MAX_EVENT_BYTES = 2 * 1024 * 1024 +_DIAGNOSTICS_MAX_ITEMS = 500 + +logger = logging.getLogger(__name__) + + +@dataclass(slots=True) +class _DetectedCliData: + headers: dict[str, str] + body_json: dict[str, Any] | None + captured_at: float + + +_detected_cli_cache: _DetectedCliData | None = None +_detect_lock = asyncio.Lock() +_recent_diagnostics: deque[dict[str, Any]] = deque(maxlen=_DIAGNOSTICS_MAX_ITEMS) + + +def get_recent_diagnostics(limit: int = 100) -> list[dict[str, Any]]: + if limit <= 0: + return [] + capped = min(limit, _DIAGNOSTICS_MAX_ITEMS) + return list(_recent_diagnostics)[-capped:] + + +async def create_message( + payload: dict[str, JsonValue], + headers: Mapping[str, str], + *, + base_url: str | None = None, + session: aiohttp.ClientSession | None = None, + credentials: AnthropicCredentials | None = None, +) -> dict[str, JsonValue]: + request_payload = _prepare_payload(payload) + request_id = _extract_request_id(headers) + creds = await _resolve_valid_credentials(credentials) + inbound_headers = _filter_inbound_headers(headers) + preflight = _build_payload_diagnostics(request_payload) + + cli_data = await _get_detected_cli_data() + merged_headers = 0 + system_injected = False + detected_system_chars = _system_text_chars_from_detected(cli_data) + if cli_data is not None: + merged_headers = _merge_cli_headers(inbound_headers, cli_data.headers) + system_injected = _inject_system_prompt(request_payload, cli_data.body_json) + + post_mutation = _build_payload_diagnostics(request_payload) + _record_diagnostic( + { + "kind": 
"anthropic_api_preflight", + "request_id": request_id, + "stream": False, + "model": request_payload.get("model"), + "cli_detected": cli_data is not None, + "cli_merged_headers": merged_headers, + "system_injected": system_injected, + "detected_system_chars": detected_system_chars, + "pre": preflight, + "post": post_mutation, + } + ) + + request_headers = _build_request_headers( + inbound_headers, + access_token=creds.bearer_token, + stream=False, + ) + response_payload, status_code = await _request_json( + request_payload, + request_headers, + base_url=base_url, + session=session, + creds=creds, + ) + usage = _extract_usage_payload(response_payload) + _record_diagnostic( + { + "kind": "anthropic_api_response", + "request_id": request_id, + "stream": False, + "status_code": status_code, + "usage": usage, + "error_type": _extract_error_type(response_payload), + } + ) + if status_code >= 400: + raise AnthropicProxyError(status_code, _ensure_error_payload(response_payload, status_code)) + return response_payload + + +async def stream_messages( + payload: dict[str, JsonValue], + headers: Mapping[str, str], + *, + base_url: str | None = None, + session: aiohttp.ClientSession | None = None, + credentials: AnthropicCredentials | None = None, +) -> AsyncIterator[str]: + request_payload = _prepare_payload(payload) + request_payload["stream"] = True + request_id = _extract_request_id(headers) + + creds = await _resolve_valid_credentials(credentials) + inbound_headers = _filter_inbound_headers(headers) + preflight = _build_payload_diagnostics(request_payload) + + cli_data = await _get_detected_cli_data() + merged_headers = 0 + system_injected = False + detected_system_chars = _system_text_chars_from_detected(cli_data) + if cli_data is not None: + merged_headers = _merge_cli_headers(inbound_headers, cli_data.headers) + system_injected = _inject_system_prompt(request_payload, cli_data.body_json) + + post_mutation = _build_payload_diagnostics(request_payload) + 
_record_diagnostic( + { + "kind": "anthropic_api_preflight", + "request_id": request_id, + "stream": True, + "model": request_payload.get("model"), + "cli_detected": cli_data is not None, + "cli_merged_headers": merged_headers, + "system_injected": system_injected, + "detected_system_chars": detected_system_chars, + "pre": preflight, + "post": post_mutation, + } + ) + + request_headers = _build_request_headers( + inbound_headers, + access_token=creds.bearer_token, + stream=True, + ) + + async for block in _stream_request( + request_payload, + request_headers, + base_url=base_url, + session=session, + creds=creds, + ): + yield block + + +def _prepare_payload(payload: dict[str, JsonValue]) -> dict[str, JsonValue]: + data = dict(payload) + if data.get("temperature") is None: + data.pop("temperature", None) + if "temperature" in data and "top_p" in data: + data.pop("top_p", None) + return data + + +def _filter_inbound_headers(headers: Mapping[str, str]) -> dict[str, str]: + filtered: dict[str, str] = {} + for key, value in headers.items(): + lower = key.lower() + if lower in _IGNORE_INBOUND_HEADERS: + continue + if lower.startswith("x-forwarded-"): + continue + if lower.startswith("cf-"): + continue + filtered[key] = value + return filtered + + +def _build_request_headers( + headers: dict[str, str], + *, + access_token: str, + stream: bool, +) -> dict[str, str]: + settings = get_settings() + request_headers = dict(headers) + request_headers["Authorization"] = f"Bearer {access_token}" + request_headers["Content-Type"] = "application/json" + request_headers["Accept"] = "text/event-stream" if stream else "application/json" + request_headers["anthropic-version"] = settings.anthropic_api_version + beta = settings.anthropic_api_beta + if beta and beta.strip(): + request_headers["anthropic-beta"] = beta.strip() + return request_headers + + +async def _resolve_valid_credentials(credentials: AnthropicCredentials | None = None) -> AnthropicCredentials: + resolved = credentials + 
if resolved is None: + resolved = await resolve_anthropic_credentials() + if resolved is None: + raise AnthropicProxyError( + 503, + anthropic_error_payload( + "api_error", + "Anthropic credentials not found. Set " + "CODEX_LB_ANTHROPIC_USAGE_BEARER_TOKEN or configure Claude credentials.", + ), + ) + + if _is_token_expiring_soon(resolved): + refreshed = await refresh_anthropic_access_token(resolved) + if refreshed is not None: + return refreshed + return resolved + + +def _is_token_expiring_soon(credentials: AnthropicCredentials) -> bool: + expires_at_ms = credentials.expires_at_ms + if expires_at_ms is None: + return False + now_ms = int(time.time() * 1000) + return expires_at_ms <= now_ms + 60_000 + + +async def _request_json( + payload: dict[str, JsonValue], + headers: dict[str, str], + *, + base_url: str | None, + session: aiohttp.ClientSession | None, + creds: AnthropicCredentials, +) -> tuple[dict[str, JsonValue], int]: + response, status = await _request_json_once(payload, headers, base_url=base_url, session=session) + if creds.refresh_token and _should_attempt_oauth_refresh(status, response): + refreshed = await refresh_anthropic_access_token(creds) + if refreshed is not None: + headers = dict(headers) + headers["Authorization"] = f"Bearer {refreshed.bearer_token}" + response, status = await _request_json_once(payload, headers, base_url=base_url, session=session) + return response, status + + +async def _request_json_once( + payload: dict[str, JsonValue], + headers: dict[str, str], + *, + base_url: str | None, + session: aiohttp.ClientSession | None, +) -> tuple[dict[str, JsonValue], int]: + settings = get_settings() + timeout = aiohttp.ClientTimeout(total=settings.anthropic_api_timeout_seconds) + client = session or get_http_client().session + url = f"{(base_url or settings.anthropic_api_base_url).rstrip('/')}/v1/messages" + + try: + async with client.post(url, json=payload, headers=headers, timeout=timeout) as response: + data = await _safe_json(response) 
+ return data, response.status + except aiohttp.ClientError as exc: + raise AnthropicProxyError(502, anthropic_error_payload("api_error", f"Upstream unavailable: {exc}")) from exc + except asyncio.TimeoutError as exc: + raise AnthropicProxyError(504, anthropic_error_payload("api_error", "Anthropic API timeout")) from exc + + +async def _stream_request( + payload: dict[str, JsonValue], + headers: dict[str, str], + *, + base_url: str | None, + session: aiohttp.ClientSession | None, + creds: AnthropicCredentials, +) -> AsyncIterator[str]: + try: + async for block in _stream_request_once(payload, headers, base_url=base_url, session=session): + yield block + return + except AnthropicProxyError as exc: + if not creds.refresh_token or not _should_attempt_oauth_refresh(exc.status_code, exc.payload): + raise + + refreshed = await refresh_anthropic_access_token(creds) + if refreshed is None: + raise AnthropicProxyError( + 401, + anthropic_error_payload("authentication_error", "Anthropic authentication failed"), + ) + + retry_headers = dict(headers) + retry_headers["Authorization"] = f"Bearer {refreshed.bearer_token}" + async for block in _stream_request_once(payload, retry_headers, base_url=base_url, session=session): + yield block + + +async def _stream_request_once( + payload: dict[str, JsonValue], + headers: dict[str, str], + *, + base_url: str | None, + session: aiohttp.ClientSession | None, +) -> AsyncIterator[str]: + settings = get_settings() + timeout = aiohttp.ClientTimeout(total=None, sock_connect=settings.anthropic_api_timeout_seconds, sock_read=None) + client = session or get_http_client().session + url = f"{(base_url or settings.anthropic_api_base_url).rstrip('/')}/v1/messages" + + try: + async with client.post(url, json=payload, headers=headers, timeout=timeout) as response: + if response.status >= 400: + data = await _safe_json(response) + raise AnthropicProxyError(response.status, _ensure_error_payload(data, response.status)) + + async for event_block in 
_iter_sse_event_blocks(response): + if event_block.strip(): + yield event_block + except AnthropicProxyError: + raise + except aiohttp.ClientError as exc: + raise AnthropicProxyError(502, anthropic_error_payload("api_error", f"Upstream unavailable: {exc}")) from exc + except asyncio.TimeoutError as exc: + raise AnthropicProxyError(504, anthropic_error_payload("api_error", "Anthropic API timeout")) from exc + + +async def _safe_json(response: aiohttp.ClientResponse) -> dict[str, JsonValue]: + try: + data = await response.json(content_type=None) + except Exception: + text = (await response.text()).strip() + if text: + return anthropic_error_payload("api_error", text) + return {} + if isinstance(data, dict): + return data + return {} + + +def _ensure_error_payload(payload: dict[str, JsonValue], status_code: int) -> dict[str, JsonValue]: + if payload.get("type") == "error" and isinstance(payload.get("error"), dict): + return payload + message: str = f"Anthropic API error ({status_code})" + payload_message = payload.get("message") + if isinstance(payload_message, str) and payload_message: + message = payload_message + return anthropic_error_payload("api_error", message) + + +def _should_attempt_oauth_refresh(status_code: int, payload: dict[str, JsonValue]) -> bool: + if status_code == 401: + return True + if status_code != 403: + return False + + error = payload.get("error") + if not isinstance(error, dict): + return False + error_type = error.get("type") + if error_type not in {"permission_error", "authentication_error"}: + return False + message = error.get("message") + if not isinstance(message, str): + return False + lowered = message.casefold() + return "oauth" in lowered and "revoked" in lowered + + +async def _iter_sse_event_blocks(response: aiohttp.ClientResponse) -> AsyncIterator[str]: + buffer = bytearray() + async for chunk in response.content.iter_chunked(8192): + if not chunk: + continue + buffer.extend(chunk) + while True: + separator = 
_find_separator(buffer) + if separator is None: + if len(buffer) > _MAX_EVENT_BYTES: + raise AnthropicProxyError( + 502, + anthropic_error_payload("api_error", "Streaming event exceeded size limit"), + ) + break + index, sep_len = separator + end = index + sep_len + raw = bytes(buffer[:end]) + del buffer[:end] + yield raw.decode("utf-8", errors="replace") + if buffer: + yield bytes(buffer).decode("utf-8", errors="replace") + + +def _find_separator(buffer: bytes | bytearray) -> tuple[int, int] | None: + pos_crlf = buffer.find(b"\r\n\r\n") + pos_lf = buffer.find(b"\n\n") + options: list[tuple[int, int]] = [] + if pos_crlf >= 0: + options.append((pos_crlf, 4)) + if pos_lf >= 0: + options.append((pos_lf, 2)) + if not options: + return None + return min(options, key=lambda item: item[0]) + + +async def _get_detected_cli_data() -> _DetectedCliData | None: + settings = get_settings() + if not settings.anthropic_api_detect_cli_headers: + return None + + global _detected_cli_cache + cached = _detected_cli_cache + now = time.monotonic() + if cached is not None and now - cached.captured_at < _CACHE_TTL_SECONDS: + return cached + + async with _detect_lock: + cached = _detected_cli_cache + now = time.monotonic() + if cached is not None and now - cached.captured_at < _CACHE_TTL_SECONDS: + return cached + + detected = await _detect_cli_headers_and_body() + _detected_cli_cache = detected + return detected + + +async def _detect_cli_headers_and_body() -> _DetectedCliData | None: + cli_binary = _find_claude_binary() + if cli_binary is None: + return None + + captured_headers: dict[str, str] = {} + captured_body: dict[str, Any] | None = None + app = web.Application() + + async def handle_messages(request: web.Request) -> web.Response: + nonlocal captured_headers, captured_body + captured_headers = { + key.lower(): value + for key, value in request.headers.items() + if key.lower() not in {"host", "authorization", "x-api-key", "content-length"} + } + raw = await request.read() + try: + 
def _find_claude_binary() -> str | None:
    """Resolve the Claude CLI executable: explicit setting first, then PATH."""
    settings = get_settings()
    configured = settings.anthropic_sdk_cli_path
    if configured:
        candidate = Path(configured).expanduser()
        # Honor the explicit override only when it is an existing executable.
        if candidate.exists() and os.access(candidate, os.X_OK):
            return str(candidate)
    return shutil_which("claude")


def shutil_which(binary: str) -> str | None:
    """Return the full PATH resolution of *binary* via :func:`shutil.which` (None when absent)."""
    from shutil import which

    return which(binary)


async def _run_claude_detection_command(cli_binary: str, port: int) -> None:
    """Run ``claude test`` against a local capture server to sniff its requests.

    Points the CLI at ``127.0.0.1:<port>`` via ``ANTHROPIC_BASE_URL``, runs it
    inside a throwaway working directory, and kills the process if it has not
    exited within 20 seconds.
    """
    env = dict(os.environ)
    env["ANTHROPIC_BASE_URL"] = f"http://127.0.0.1:{port}"

    with tempfile.TemporaryDirectory(prefix="codex-lb-claude-detect-") as temp_dir:
        process = await asyncio.create_subprocess_exec(
            cli_binary,
            "test",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=temp_dir,
            env=env,
        )
        try:
            await asyncio.wait_for(process.wait(), timeout=20)
        except TimeoutError:
            process.kill()
            await process.wait()
dict[str, str], cli_headers: dict[str, str]) -> int: + blocked = {"authorization", "x-api-key", "host", "content-length"} + merged = 0 + for key, value in cli_headers.items(): + lower = key.lower() + if lower in blocked: + continue + previous = target.get(key) + if previous != value: + merged += 1 + target[key] = value + return merged + + +def _inject_system_prompt(payload: dict[str, JsonValue], body_json: dict[str, Any] | None) -> bool: + if body_json is None: + return False + settings = get_settings() + mode = settings.anthropic_api_system_prompt_injection_mode + if mode == "none": + return False + + detected_system = body_json.get("system") + if not isinstance(detected_system, (str, list, dict)): + return False + + if mode == "minimal": + if payload.get("system") is None: + payload["system"] = detected_system + return True + return False + + previous = payload.get("system") + payload["system"] = detected_system + return previous != detected_system + + +def _extract_request_id(headers: Mapping[str, str]) -> str | None: + for key, value in headers.items(): + lower = key.lower() + if lower in {"x-request-id", "request-id"}: + request_id = value.strip() + if request_id: + return request_id + return None + + +def _system_text_chars_from_detected(cli_data: _DetectedCliData | None) -> int: + if cli_data is None or not isinstance(cli_data.body_json, dict): + return 0 + return _system_text_chars(cli_data.body_json.get("system")) + + +def _build_payload_diagnostics(payload: dict[str, JsonValue]) -> dict[str, int | str | bool | None]: + messages: Any = payload.get("messages") + tools: Any = payload.get("tools") + system: Any = payload.get("system") + max_tokens: Any = payload.get("max_tokens") + stream: Any = payload.get("stream") + message_count = len(messages) if isinstance(messages, list) else 0 + tools_count = len(tools) if isinstance(tools, list) else 0 + return { + "json_bytes": _json_size_bytes(payload), + "message_count": message_count, + "messages_text_chars": 
_messages_text_chars(messages), + "system_text_chars": _system_text_chars(system), + "tools_count": tools_count, + "max_tokens": _as_int(max_tokens), + "stream": bool(stream), + } + + +def _json_size_bytes(payload: dict[str, JsonValue]) -> int: + try: + return len(json.dumps(payload, separators=(",", ":"), ensure_ascii=False).encode("utf-8")) + except Exception: + return -1 + + +def _messages_text_chars(messages: Any) -> int: + if not isinstance(messages, list): + return 0 + total = 0 + for item in messages: + if not isinstance(item, dict): + continue + total += _content_text_chars(item.get("content")) + return total + + +def _content_text_chars(content: Any) -> int: + if isinstance(content, str): + return len(content) + if isinstance(content, list): + total = 0 + for block in content: + if isinstance(block, str): + total += len(block) + continue + if isinstance(block, dict): + text = block.get("text") + if isinstance(text, str): + total += len(text) + return total + if isinstance(content, dict): + text = content.get("text") + if isinstance(text, str): + return len(text) + return 0 + + +def _system_text_chars(system: Any) -> int: + if isinstance(system, str): + return len(system) + if isinstance(system, list): + return sum(_system_text_chars(item) for item in system) + if isinstance(system, dict): + text = system.get("text") + if isinstance(text, str): + return len(text) + return 0 + + +def _as_int(value: Any) -> int | None: + if isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, float): + return int(value) + return None + + +def _extract_usage_payload(payload: dict[str, JsonValue]) -> dict[str, int] | None: + usage = payload.get("usage") + if not isinstance(usage, dict): + return None + + out: dict[str, int] = {} + for key in ("input_tokens", "output_tokens", "cache_read_input_tokens", "cache_creation_input_tokens"): + value = usage.get(key) + parsed = _as_int(value) + if parsed is not None: + out[key] = 
parsed + return out or None + + +def _extract_error_type(payload: dict[str, JsonValue]) -> str | None: + error = payload.get("error") + if isinstance(error, dict): + error_type = error.get("type") + if isinstance(error_type, str): + return error_type + return None + + +def _record_diagnostic(entry: dict[str, Any]) -> None: + record = { + "ts": datetime.now(timezone.utc).isoformat(), + **entry, + } + _recent_diagnostics.append(record) + logger.warning("anthropic_api_diag %s", json.dumps(record, default=str)) diff --git a/app/core/clients/anthropic_proxy.py b/app/core/clients/anthropic_proxy.py new file mode 100644 index 00000000..0d2026e5 --- /dev/null +++ b/app/core/clients/anthropic_proxy.py @@ -0,0 +1,1007 @@ +from __future__ import annotations + +import asyncio +import importlib +import json +import logging +import uuid +from collections.abc import AsyncIterator, Mapping +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import Any + +from app.core.config.settings import get_settings +from app.core.types import JsonValue + +logger = logging.getLogger(__name__) + + +@dataclass(slots=True) +class AnthropicProxyError(Exception): + status_code: int + payload: dict[str, JsonValue] + + def __str__(self) -> str: + return f"Anthropic proxy response error ({self.status_code})" + + +@dataclass(frozen=True, slots=True) +class _PoolKey: + model: str + session_id: str + max_tokens: int | None + temperature: float | None + system_prompt: str | None + cli_path: str | None + allowed_tools_signature: str | None + permission_mode: str | None + cwd: str | None + max_thinking_tokens: int | None + mcp_servers_signature: str | None + + +@dataclass(slots=True) +class _PooledClient: + client: Any + broken: bool = False + + +class _ClientPool: + def __init__(self, max_size: int) -> None: + self._max_size = max_size + self._queue: asyncio.LifoQueue[_PooledClient] = asyncio.LifoQueue(maxsize=max_size) + self._created = 0 + self._lock = 
    async def acquire(self, create_client: Any, *, acquire_timeout_seconds: float) -> _PooledClient:
        """Check out an idle client, grow the pool, or wait for a release.

        Order of preference: an already-idle client from the LIFO queue, then
        a freshly created client while under ``_max_size``, then a timed wait
        for another caller's release.

        Raises:
            AnthropicProxyError: 503 when the wait for a free client times out.
        """
        try:
            return self._queue.get_nowait()
        except asyncio.QueueEmpty:
            pass

        should_create = False
        async with self._lock:
            # Reserve a creation slot under the lock so concurrent callers
            # cannot overshoot _max_size.
            if self._created < self._max_size:
                self._created += 1
                should_create = True

        if should_create:
            try:
                client = await create_client()
                return _PooledClient(client=client)
            except Exception:
                # Creation failed: hand the reserved slot back.
                async with self._lock:
                    self._created = max(0, self._created - 1)
                raise

        try:
            return await asyncio.wait_for(self._queue.get(), timeout=acquire_timeout_seconds)
        except asyncio.TimeoutError as exc:
            raise AnthropicProxyError(
                503,
                anthropic_error_payload(
                    "api_error",
                    "Timed out waiting for an available Claude SDK session",
                ),
            ) from exc

    async def release(self, pooled: _PooledClient) -> None:
        """Return a client to the pool, disposing of broken or surplus ones."""
        if pooled.broken:
            # Broken clients are never reused: drop them and free the slot.
            await _safe_disconnect(pooled.client)
            async with self._lock:
                self._created = max(0, self._created - 1)
            return

        try:
            self._queue.put_nowait(pooled)
        except asyncio.QueueFull:
            # Defensive: disconnect instead of blocking if the queue is
            # unexpectedly full, and free the accounting slot.
            await _safe_disconnect(pooled.client)
            async with self._lock:
                self._created = max(0, self._created - 1)

    async def close(self) -> None:
        """Disconnect every idle client and release their creation slots.

        Clients currently checked out are not touched here.
        """
        drained = 0
        while True:
            try:
                pooled = self._queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            drained += 1
            await _safe_disconnect(pooled.client)
        if drained:
            async with self._lock:
                self._created = max(0, self._created - drained)
release(self, key: _PoolKey, pooled: _PooledClient) -> None: + async with self._lock: + pool = self._pools.get(key) + if pool is None: + await _safe_disconnect(pooled.client) + return + await pool.release(pooled) + + async def close_all(self) -> None: + async with self._lock: + pools = list(self._pools.values()) + self._pools.clear() + for pool in pools: + await pool.close() + + +_POOL_MANAGER = _PoolManager() + + +def anthropic_error_payload(error_type: str, message: str) -> dict[str, JsonValue]: + error_detail: dict[str, JsonValue] = { + "type": error_type, + "message": message, + } + return { + "type": "error", + "error": error_detail, + } + + +async def create_message( + payload: dict[str, JsonValue], + headers: Mapping[str, str], + *, + base_url: str | None = None, + session: object | None = None, +) -> dict[str, JsonValue]: + del base_url, session + + message_payload = _build_sdk_query_message(payload) + session_id, session_source = _resolve_session_id_with_source(payload) + if session_id is None: + session_id = _ephemeral_session_id() + session_source = "ephemeral" + + _log_sdk_preflight( + payload, + message_payload, + headers, + stream=False, + session_id=session_id, + session_source=session_source, + ) + + try: + async with _acquire_client( + payload, + session_id=session_id, + poolable=session_source != "ephemeral", + ) as client: + await _send_query(client, message_payload, session_id=session_id) + collected = [message async for message in client.receive_response()] + except AnthropicProxyError: + raise + except Exception as exc: + raise _map_sdk_error(exc) from exc + + response_payload = _build_non_stream_response(collected, requested_model=_extract_request_model(payload)) + _log_sdk_result(response_payload, headers, stream=False) + return response_payload + + +async def stream_messages( + payload: dict[str, JsonValue], + headers: Mapping[str, str], + *, + base_url: str | None = None, + session: object | None = None, +) -> AsyncIterator[str]: + del 
base_url, session + + message_payload = _build_sdk_query_message(payload) + session_id, session_source = _resolve_session_id_with_source(payload) + if session_id is None: + session_id = _ephemeral_session_id() + session_source = "ephemeral" + model = _extract_request_model(payload) + message_id = f"msg_{uuid.uuid4().hex}" + yielded_start = False + content_block_index = 0 + emitted_content = False + + _log_sdk_preflight( + payload, + message_payload, + headers, + stream=True, + session_id=session_id, + session_source=session_source, + ) + + try: + async with _acquire_client( + payload, + session_id=session_id, + poolable=session_source != "ephemeral", + ) as client: + await _send_query(client, message_payload, session_id=session_id) + + async for sdk_message in client.receive_response(): + message_type = type(sdk_message).__name__ + + if not yielded_start: + yielded_start = True + yield _to_sse( + { + "type": "message_start", + "message": { + "id": message_id, + "type": "message", + "role": "assistant", + "model": model, + "content": [], + "stop_reason": None, + "stop_sequence": None, + "usage": _stream_usage_defaults({}), + }, + } + ) + + if message_type == "AssistantMessage": + for block in _extract_content_blocks(sdk_message): + block_type = block.get("type") + if block_type == "text": + text_value = block.get("text") + if not isinstance(text_value, str) or not text_value: + continue + yield _to_sse( + { + "type": "content_block_start", + "index": content_block_index, + "content_block": {"type": "text", "text": ""}, + } + ) + yield _to_sse( + { + "type": "content_block_delta", + "index": content_block_index, + "delta": {"type": "text_delta", "text": text_value}, + } + ) + yield _to_sse({"type": "content_block_stop", "index": content_block_index}) + content_block_index += 1 + emitted_content = True + elif block_type in {"tool_use", "tool_result"}: + yield _to_sse( + { + "type": "content_block_start", + "index": content_block_index, + "content_block": block, + } + 
) + yield _to_sse({"type": "content_block_stop", "index": content_block_index}) + content_block_index += 1 + emitted_content = True + + if message_type == "ResultMessage": + if getattr(sdk_message, "is_error", False): + error_message = _extract_result_error_message(sdk_message) + yield _to_sse( + anthropic_error_payload( + "api_error", + error_message, + ) + ) + return + + if not emitted_content: + fallback_text = _extract_result_text(sdk_message) + if fallback_text: + yield _to_sse( + { + "type": "content_block_start", + "index": content_block_index, + "content_block": {"type": "text", "text": ""}, + } + ) + yield _to_sse( + { + "type": "content_block_delta", + "index": content_block_index, + "delta": {"type": "text_delta", "text": fallback_text}, + } + ) + yield _to_sse({"type": "content_block_stop", "index": content_block_index}) + content_block_index += 1 + emitted_content = True + + stop_reason = _extract_stop_reason(sdk_message) + usage = _stream_usage_defaults(_extract_usage_fields(sdk_message)) + yield _to_sse( + { + "type": "message_delta", + "delta": { + "stop_reason": stop_reason, + "stop_sequence": None, + }, + "usage": usage, + } + ) + yield _to_sse({"type": "message_stop"}) + _log_sdk_stream_result(usage, headers) + return + + if yielded_start: + yield _to_sse( + anthropic_error_payload( + "api_error", + "Claude SDK stream closed without final result", + ) + ) + except AnthropicProxyError: + raise + except Exception as exc: + raise _map_sdk_error(exc) from exc + + +def parse_sse_data_payload(event_block: str) -> dict[str, JsonValue] | None: + data_lines: list[str] = [] + for raw_line in event_block.splitlines(): + if not raw_line or raw_line.startswith(":"): + continue + if not raw_line.startswith("data:"): + continue + value = raw_line[5:] + if value.startswith(" "): + value = value[1:] + data_lines.append(value) + + if not data_lines: + return None + joined = "\n".join(data_lines).strip() + if not joined or joined == "[DONE]": + return None + try: + 
async def close_anthropic_client_pools() -> None:
    """Disconnect and drop every pooled Claude SDK client (shutdown hook)."""
    await _POOL_MANAGER.close_all()


def _ephemeral_session_id() -> str:
    # One-off session IDs for requests with no client-supplied or default
    # session; callers pass poolable=False for these, so they never enter a pool.
    return f"codexlb_{uuid.uuid4().hex}"


@asynccontextmanager
async def _acquire_client(
    payload: dict[str, JsonValue],
    *,
    session_id: str,
    poolable: bool,
) -> AsyncIterator[Any]:
    """Yield a connected Claude SDK client, pooled when enabled and poolable.

    Non-pooled path: a fresh client is connected and always disconnected on
    exit. Pooled path: a client is checked out of the per-key pool; if the
    caller raises, the client is marked broken so the pool disposes of it
    instead of reusing it.
    """
    sdk = _require_sdk()
    settings = get_settings()
    if not settings.anthropic_sdk_pool_enabled or not poolable:
        options = _build_sdk_options(payload)
        client = sdk.ClaudeSDKClient(options)
        await client.connect()
        try:
            yield client
        finally:
            await _safe_disconnect(client)
        return

    key = _pool_key_from_payload(payload, session_id=session_id)

    async def _create_client() -> Any:
        # Factory handed to the pool; invoked only when the pool grows.
        options = _build_sdk_options(payload)
        client = sdk.ClaudeSDKClient(options)
        await client.connect()
        return client

    pooled = await _POOL_MANAGER.acquire(
        key,
        _create_client,
        max_size=settings.anthropic_sdk_pool_size,
        acquire_timeout_seconds=settings.anthropic_sdk_pool_acquire_timeout_seconds,
    )

    broken = False
    try:
        yield pooled.client
    except Exception:
        # Any error while the caller held the client taints it: the session may
        # be mid-conversation, so the pool must not hand it out again.
        broken = True
        raise
    finally:
        if broken:
            pooled.broken = True
        await _POOL_MANAGER.release(key, pooled)
_as_int(payload.get("max_thinking_tokens")) + mcp_servers_signature = _signature_for_json(payload.get("mcp_servers")) + + return _PoolKey( + model=model, + session_id=session_id, + max_tokens=max_tokens, + temperature=temperature, + system_prompt=system_prompt, + cli_path=_normalize_str(settings.anthropic_sdk_cli_path), + allowed_tools_signature=allowed_tools_signature, + permission_mode=permission_mode, + cwd=cwd, + max_thinking_tokens=max_thinking_tokens, + mcp_servers_signature=mcp_servers_signature, + ) + + +def _signature_for_json(value: JsonValue | None) -> str | None: + if value is None: + return None + try: + return json.dumps(value, ensure_ascii=True, separators=(",", ":"), sort_keys=True) + except Exception: + return None + + +def _normalize_str(value: object) -> str | None: + if isinstance(value, str): + stripped = value.strip() + return stripped or None + return None + + +def _as_int(value: object) -> int | None: + if isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, float): + return int(value) + return None + + +async def _safe_disconnect(client: Any) -> None: + try: + await client.disconnect() + except Exception: + pass + + +def _require_sdk() -> Any: + try: + sdk = importlib.import_module("claude_agent_sdk") + except Exception as exc: + raise AnthropicProxyError( + 503, + anthropic_error_payload( + "api_error", + "claude-agent-sdk is required for Anthropic provider mode", + ), + ) from exc + + required_attrs = ("ClaudeSDKClient", "ClaudeAgentOptions") + for attr in required_attrs: + if not hasattr(sdk, attr): + raise AnthropicProxyError( + 503, + anthropic_error_payload( + "api_error", + f"claude-agent-sdk missing required attribute: {attr}", + ), + ) + return sdk + + +def _build_sdk_options(payload: dict[str, JsonValue]) -> Any: + sdk = _require_sdk() + options = sdk.ClaudeAgentOptions() + + model = _extract_request_model(payload) + if model and hasattr(options, "model"): + setattr(options, 
def _build_sdk_query_message(payload: dict[str, JsonValue]) -> dict[str, JsonValue]:
    """Wrap the flattened conversation into a single SDK user message."""
    return {
        "type": "user",
        "message": {
            "role": "user",
            "content": _build_prompt_from_messages(payload),
        },
    }


def _build_prompt_from_messages(payload: dict[str, JsonValue]) -> str:
    """Flatten Anthropic ``messages`` into one ``role: text`` transcript string.

    Non-dict entries, non-string roles, and messages with no extractable text
    are dropped.

    Raises:
        AnthropicProxyError: 400 when ``messages`` is not a list or contains no
            usable text content.
    """
    raw_messages = payload.get("messages")
    if not isinstance(raw_messages, list):
        raise AnthropicProxyError(
            400,
            anthropic_error_payload("invalid_request_error", "messages must be a list"),
        )

    transcript: list[str] = []
    for raw_message in raw_messages:
        if not isinstance(raw_message, dict):
            continue
        role = raw_message.get("role")
        if not isinstance(role, str):
            continue
        text = _content_to_text(raw_message.get("content"))
        if text:
            transcript.append(f"{role}: {text}")

    if not transcript:
        raise AnthropicProxyError(
            400,
            anthropic_error_payload("invalid_request_error", "messages must include user text content"),
        )

    return "\n\n".join(transcript)
if not content: + continue + prompt_parts.append(f"{role}: {content}") + + if not prompt_parts: + raise AnthropicProxyError( + 400, + anthropic_error_payload("invalid_request_error", "messages must include user text content"), + ) + + return "\n\n".join(prompt_parts) + + +def _content_to_text(value: JsonValue) -> str: + if isinstance(value, str): + return value + if isinstance(value, list): + chunks: list[str] = [] + for item in value: + if isinstance(item, str): + chunks.append(item) + continue + if not isinstance(item, dict): + continue + block_type = item.get("type") + if block_type == "text": + text_value = item.get("text") + if isinstance(text_value, str) and text_value: + chunks.append(text_value) + elif block_type == "tool_result": + result_content = item.get("content") + if isinstance(result_content, str) and result_content: + chunks.append(result_content) + return "\n".join(chunks) + if isinstance(value, dict): + text_value = value.get("text") + if isinstance(text_value, str): + return text_value + return "" + + +def _extract_system_prompt(payload: dict[str, JsonValue]) -> str | None: + system_value = payload.get("system") + if system_value is None: + return None + if isinstance(system_value, str): + stripped = system_value.strip() + return stripped or None + if isinstance(system_value, list): + parts: list[str] = [] + for block in system_value: + if isinstance(block, str) and block.strip(): + parts.append(block.strip()) + continue + if isinstance(block, dict) and block.get("type") == "text": + text_value = block.get("text") + if isinstance(text_value, str) and text_value.strip(): + parts.append(text_value.strip()) + if parts: + return "\n\n".join(parts) + return None + + +def _extract_request_model(payload: dict[str, JsonValue]) -> str: + model = payload.get("model") + if isinstance(model, str) and model.strip(): + return model.strip() + raise AnthropicProxyError( + 400, + anthropic_error_payload("invalid_request_error", "model is required"), + ) + + +def 
_resolve_session_id(payload: dict[str, JsonValue]) -> str | None: + session_id, _ = _resolve_session_id_with_source(payload) + return session_id + + +def _resolve_session_id_with_source(payload: dict[str, JsonValue]) -> tuple[str | None, str]: + direct_session_id = payload.get("session_id") + if isinstance(direct_session_id, str) and direct_session_id.strip(): + return direct_session_id.strip(), "session_id" + + metadata = payload.get("metadata") + if isinstance(metadata, dict): + metadata_session_id = metadata.get("session_id") + if isinstance(metadata_session_id, str) and metadata_session_id.strip(): + return metadata_session_id.strip(), "metadata" + + default_session_id = get_settings().anthropic_sdk_default_session_id + if default_session_id and default_session_id.strip(): + return default_session_id.strip(), "default" + return None, "none" + + +def _extract_request_id(headers: Mapping[str, str]) -> str: + for key, value in headers.items(): + if key.lower() in {"x-request-id", "request-id"} and value.strip(): + return value.strip() + return "-" + + +def _json_size_bytes(payload: dict[str, JsonValue]) -> int: + try: + return len(json.dumps(payload, separators=(",", ":"), ensure_ascii=False).encode("utf-8")) + except Exception: + return -1 + + +def _payload_message_count(payload: dict[str, JsonValue]) -> int: + messages = payload.get("messages") + if isinstance(messages, list): + return len(messages) + return 0 + + +def _system_text_chars_for_payload(payload: dict[str, JsonValue]) -> int: + system = payload.get("system") + if isinstance(system, str): + return len(system) + if isinstance(system, list): + total = 0 + for item in system: + if isinstance(item, str): + total += len(item) + elif isinstance(item, dict): + text = item.get("text") + if isinstance(text, str): + total += len(text) + return total + return 0 + + +def _prompt_chars_from_message_payload(message_payload: dict[str, JsonValue]) -> int: + message = message_payload.get("message") + if not 
isinstance(message, dict): + return 0 + content = message.get("content") + if isinstance(content, str): + return len(content) + return 0 + + +def _session_tail(session_id: str) -> str: + if len(session_id) <= 8: + return session_id + return session_id[-8:] + + +def _log_sdk_preflight( + payload: dict[str, JsonValue], + message_payload: dict[str, JsonValue], + headers: Mapping[str, str], + *, + stream: bool, + session_id: str, + session_source: str, +) -> None: + logger.warning( + "anthropic_sdk_diag stage=preflight " + "request_id=%s stream=%s model=%s session_source=%s session_tail=%s " + "payload_bytes=%s message_count=%s prompt_chars=%s system_chars=%s", + _extract_request_id(headers), + stream, + _extract_request_model(payload), + session_source, + _session_tail(session_id), + _json_size_bytes(payload), + _payload_message_count(payload), + _prompt_chars_from_message_payload(message_payload), + _system_text_chars_for_payload(payload), + ) + + +def _log_sdk_result(response_payload: dict[str, JsonValue], headers: Mapping[str, str], *, stream: bool) -> None: + usage = response_payload.get("usage") + if not isinstance(usage, dict): + usage = {} + logger.warning( + "anthropic_sdk_diag stage=result " + "request_id=%s stream=%s input_tokens=%s cached_input_tokens=%s output_tokens=%s", + _extract_request_id(headers), + stream, + usage.get("input_tokens"), + usage.get("cache_read_input_tokens"), + usage.get("output_tokens"), + ) + + +def _log_sdk_stream_result(usage: dict[str, JsonValue], headers: Mapping[str, str]) -> None: + logger.warning( + "anthropic_sdk_diag stage=result " + "request_id=%s stream=%s input_tokens=%s cached_input_tokens=%s output_tokens=%s", + _extract_request_id(headers), + True, + usage.get("input_tokens"), + usage.get("cache_read_input_tokens"), + usage.get("output_tokens"), + ) + + +async def _send_query(client: Any, message_payload: dict[str, JsonValue], *, session_id: str | None) -> None: + async def message_iter() -> AsyncIterator[dict[str, 
def _build_non_stream_response(
    sdk_messages: list[Any],
    *,
    requested_model: str,
) -> dict[str, JsonValue]:
    """Fold a full SDK message transcript into one Anthropic Messages response.

    Collects content blocks from every ``AssistantMessage`` and takes stop
    reason/usage from the final ``ResultMessage``. Message types are matched
    by class name (the SDK is imported lazily via ``_require_sdk``, so its
    types are never referenced directly).

    Raises:
        AnthropicProxyError: 502 when no result message arrived or the result
            is flagged as an error.
    """
    content_blocks: list[JsonValue] = []
    result_message: Any | None = None

    for sdk_message in sdk_messages:
        message_type = type(sdk_message).__name__
        if message_type == "AssistantMessage":
            content_blocks.extend(_extract_content_blocks(sdk_message))
        if message_type == "ResultMessage":
            # Keep the last result message seen.
            result_message = sdk_message

    if result_message is None:
        raise AnthropicProxyError(
            502,
            anthropic_error_payload("api_error", "Claude SDK did not return a result message"),
        )

    if getattr(result_message, "is_error", False):
        raise AnthropicProxyError(
            502,
            anthropic_error_payload("api_error", _extract_result_error_message(result_message)),
        )

    # Fallback: surface result-message text when no assistant content arrived,
    # so the client still receives some content.
    if not content_blocks:
        fallback_text = _extract_result_text(result_message)
        if fallback_text:
            content_blocks.append({"type": "text", "text": fallback_text})

    usage = _extract_usage_fields(result_message)
    return {
        "id": f"msg_{uuid.uuid4().hex}",
        "type": "message",
        "role": "assistant",
        # Echo the caller's requested model.
        "model": requested_model,
        "content": content_blocks,
        "stop_reason": _extract_stop_reason(result_message),
        "stop_sequence": None,
        "usage": usage,
    }
getattr(block, "input", None) + if isinstance(block_id, str) and isinstance(name, str) and isinstance(tool_input, dict): + blocks.append( + { + "type": "tool_use", + "id": block_id, + "name": name, + "input": tool_input, + } + ) + elif block_type == "tool_result": + tool_use_id = getattr(block, "tool_use_id", None) + block_content = getattr(block, "content", None) + is_error = getattr(block, "is_error", None) + if isinstance(tool_use_id, str): + tool_result_block: dict[str, JsonValue] = { + "type": "tool_result", + "tool_use_id": tool_use_id, + } + if isinstance(block_content, str): + tool_result_block["content"] = block_content + elif isinstance(block_content, list): + normalized_content = _normalize_json_list(block_content) + if normalized_content is not None: + tool_result_block["content"] = normalized_content + if isinstance(is_error, bool): + tool_result_block["is_error"] = is_error + blocks.append(tool_result_block) + return blocks + + +def _normalize_json_list(value: list[Any]) -> list[JsonValue] | None: + normalized: list[JsonValue] = [] + for item in value: + if isinstance(item, (str, int, float, bool)) or item is None: + normalized.append(item) + elif isinstance(item, Mapping): + normalized.append(dict(item)) + elif isinstance(item, list): + nested = _normalize_json_list(item) + if nested is None: + return None + normalized.append(nested) + else: + return None + return normalized + + +def _extract_usage_fields(result_message: Any) -> dict[str, JsonValue]: + usage_source = getattr(result_message, "usage", None) + usage_input = _extract_usage_int(usage_source, "input_tokens") + usage_output = _extract_usage_int(usage_source, "output_tokens") + usage_cached = _extract_usage_int(usage_source, "cache_read_input_tokens") + usage_cache_creation = _extract_usage_int(usage_source, "cache_creation_input_tokens") + + usage: dict[str, JsonValue] = {} + if usage_input is not None: + usage["input_tokens"] = usage_input + if usage_output is not None: + 
usage["output_tokens"] = usage_output + if usage_cached is not None: + usage["cache_read_input_tokens"] = usage_cached + if usage_cache_creation is not None: + usage["cache_creation_input_tokens"] = usage_cache_creation + return usage + + +def _stream_usage_defaults(usage: dict[str, JsonValue]) -> dict[str, JsonValue]: + normalized = dict(usage) + if "input_tokens" not in normalized: + normalized["input_tokens"] = 0 + if "output_tokens" not in normalized: + normalized["output_tokens"] = 0 + if "cache_read_input_tokens" not in normalized: + normalized["cache_read_input_tokens"] = 0 + if "cache_creation_input_tokens" not in normalized: + normalized["cache_creation_input_tokens"] = 0 + return normalized + + +def _extract_usage_int(usage_source: Any, key: str) -> int | None: + if usage_source is None: + return None + if isinstance(usage_source, Mapping): + value = usage_source.get(key) + else: + value = getattr(usage_source, key, None) + if isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, float): + return int(value) + return None + + +def _extract_stop_reason(result_message: Any) -> str: + stop_reason = getattr(result_message, "stop_reason", None) + if isinstance(stop_reason, str) and stop_reason: + return stop_reason + return "end_turn" + + +def _extract_result_error_message(result_message: Any) -> str: + result_text = _extract_result_text(result_message) + if result_text: + return result_text + return "Claude SDK returned an error result" + + +def _extract_result_text(result_message: Any) -> str | None: + result_value = getattr(result_message, "result", None) + if isinstance(result_value, str) and result_value.strip(): + return result_value.strip() + return None + + +def _to_sse(payload: dict[str, JsonValue]) -> str: + event_type = payload.get("type") + event_name = event_type if isinstance(event_type, str) else "message" + data = json.dumps(payload, ensure_ascii=True, separators=(",", ":")) + return f"event: 
{event_name}\ndata: {data}\n\n" + + +def _map_sdk_error(exc: Exception) -> AnthropicProxyError: + error_name = type(exc).__name__ + message = str(exc).strip() or error_name + if error_name in {"CLINotFoundError", "CLIConnectionError"}: + return AnthropicProxyError(503, anthropic_error_payload("api_error", message)) + if error_name in {"ProcessError", "CLIJSONDecodeError"}: + return AnthropicProxyError(502, anthropic_error_payload("api_error", message)) + if isinstance(exc, TimeoutError): + return AnthropicProxyError(504, anthropic_error_payload("api_error", "Claude SDK timeout")) + return AnthropicProxyError(500, anthropic_error_payload("api_error", message)) diff --git a/app/core/clients/anthropic_usage.py b/app/core/clients/anthropic_usage.py new file mode 100644 index 00000000..38a5e2a5 --- /dev/null +++ b/app/core/clients/anthropic_usage.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import cast + +import aiohttp + +from app.core.clients.http import get_http_client +from app.core.config.settings import get_settings + + +class AnthropicUsageFetchError(Exception): + def __init__(self, status_code: int, message: str) -> None: + super().__init__(message) + self.status_code = status_code + self.message = message + + +@dataclass(frozen=True, slots=True) +class AnthropicUsageWindow: + used_percent: float + reset_at_epoch: int | None + window_minutes: int + + +@dataclass(frozen=True, slots=True) +class AnthropicUsageSnapshot: + five_hour: AnthropicUsageWindow | None + seven_day: AnthropicUsageWindow | None + + +async def fetch_usage_snapshot( + *, + bearer_token: str, + base_url: str | None = None, + session: aiohttp.ClientSession | None = None, +) -> AnthropicUsageSnapshot: + settings = get_settings() + usage_base = (base_url or settings.anthropic_usage_base_url).rstrip("/") + url = f"{usage_base}/api/oauth/usage" + timeout = 
aiohttp.ClientTimeout(total=settings.usage_fetch_timeout_seconds) + headers = { + "Authorization": f"Bearer {bearer_token}", + "anthropic-beta": settings.anthropic_usage_beta, + "Content-Type": "application/json", + "Accept": "application/json", + } + client = session or get_http_client().session + try: + async with client.get(url, headers=headers, timeout=timeout) as resp: + payload = await _safe_json(resp) + if resp.status >= 400: + raise AnthropicUsageFetchError( + resp.status, + _error_message(payload) or f"Usage fetch failed ({resp.status})", + ) + except AnthropicUsageFetchError: + raise + except aiohttp.ClientError as exc: + raise AnthropicUsageFetchError(0, f"Usage fetch failed: {exc}") from exc + + return _parse_usage_payload(payload) + + +async def _safe_json(resp: aiohttp.ClientResponse) -> dict[str, object]: + try: + data = await resp.json(content_type=None) + except Exception: + text = await resp.text() + return {"error": {"message": text.strip()}} + return data if isinstance(data, dict) else {"error": {"message": str(data)}} + + +def _error_message(payload: dict[str, object]) -> str | None: + error = payload.get("error") + if isinstance(error, dict): + error_payload = cast(dict[str, object], error) + message = error_payload.get("message") + if isinstance(message, str) and message.strip(): + return message.strip() + message = payload.get("message") + if isinstance(message, str) and message.strip(): + return message.strip() + return None + + +def _parse_usage_payload(payload: dict[str, object]) -> AnthropicUsageSnapshot: + five_hour = _parse_usage_window(payload.get("five_hour"), window_minutes=300) + seven_day = _parse_usage_window(payload.get("seven_day"), window_minutes=10080) + return AnthropicUsageSnapshot(five_hour=five_hour, seven_day=seven_day) + + +def _parse_usage_window(raw: object, *, window_minutes: int) -> AnthropicUsageWindow | None: + if not isinstance(raw, dict): + return None + raw_payload = cast(dict[str, object], raw) + used = 
_normalize_utilization_percent(raw_payload.get("utilization")) + if used is None: + return None + reset_at_epoch = _parse_reset_at(raw_payload.get("resets_at")) + return AnthropicUsageWindow( + used_percent=used, + reset_at_epoch=reset_at_epoch, + window_minutes=window_minutes, + ) + + +def _normalize_utilization_percent(value: object) -> float | None: + if isinstance(value, (int, float)): + numeric = float(value) + if numeric <= 1.0: + numeric *= 100.0 + return max(0.0, min(100.0, numeric)) + return None + + +def _parse_reset_at(value: object) -> int | None: + if isinstance(value, str): + parsed = _parse_iso8601(value) + if parsed is not None: + return int(parsed.timestamp()) + return None + if isinstance(value, (int, float)): + # Some payloads may send epoch seconds. + return int(value) + return None + + +def _parse_iso8601(value: str) -> datetime | None: + try: + normalized = value.replace("Z", "+00:00") + parsed = datetime.fromisoformat(normalized) + except ValueError: + return None + if parsed.tzinfo is None: + return parsed.replace(tzinfo=timezone.utc) + return parsed.astimezone(timezone.utc) diff --git a/app/core/config/settings.py b/app/core/config/settings.py index c34d59b8..5e5c02e3 100644 --- a/app/core/config/settings.py +++ b/app/core/config/settings.py @@ -2,7 +2,7 @@ from functools import lru_cache from pathlib import Path -from typing import Annotated +from typing import Annotated, Literal from pydantic import Field, field_validator from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict @@ -53,6 +53,32 @@ class Settings(BaseSettings): upstream_connect_timeout_seconds: float = 30.0 stream_idle_timeout_seconds: float = 300.0 max_sse_event_bytes: int = Field(default=2 * 1024 * 1024, gt=0) + anthropic_api_base_url: str = "https://api.anthropic.com" + anthropic_api_version: str = "2023-06-01" + anthropic_api_beta: str | None = None + anthropic_api_timeout_seconds: float = 300.0 + anthropic_api_detect_cli_headers: bool = True + 
anthropic_api_system_prompt_injection_mode: Literal["none", "minimal", "full"] = "minimal" + anthropic_oauth_token_url: str = "https://console.anthropic.com/v1/oauth/token" + anthropic_oauth_client_id: str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" + anthropic_sdk_cli_path: str | None = None + anthropic_sdk_default_session_id: str | None = None + anthropic_sdk_pool_enabled: bool = False + anthropic_sdk_pool_size: int = Field(default=4, ge=1) + anthropic_sdk_pool_acquire_timeout_seconds: float = 30.0 + anthropic_usage_base_url: str = "https://api.anthropic.com" + anthropic_usage_beta: str = "oauth-2025-04-20" + anthropic_usage_refresh_enabled: bool = True + anthropic_usage_bearer_token: str | None = None + anthropic_org_id: str | None = None + anthropic_auto_discover_org: bool = False + anthropic_credentials_discovery_enabled: bool = True + anthropic_credentials_file: Path | None = None + anthropic_credentials_helper_command: str | None = None + anthropic_credentials_cache_seconds: int = Field(default=60, ge=0) + anthropic_default_account_id: str = "anthropic_default" + anthropic_default_account_email: str = "anthropic@local" + anthropic_default_plan_type: str = "pro" auth_base_url: str = "https://auth.openai.com" oauth_client_id: str = "app_EMoamEEZ73f0CkXaXp7hrann" oauth_scope: str = "openid profile email" @@ -97,6 +123,20 @@ def _expand_encryption_key_file(cls, value: str | Path) -> Path: return Path(value).expanduser() raise TypeError("encryption_key_file must be a path") + @field_validator("anthropic_credentials_file", mode="before") + @classmethod + def _expand_anthropic_credentials_file(cls, value: str | Path | None) -> Path | None: + if value is None: + return None + if isinstance(value, Path): + return value.expanduser() + if isinstance(value, str): + stripped = value.strip() + if not stripped: + return None + return Path(stripped).expanduser() + raise TypeError("anthropic_credentials_file must be a path") + @field_validator("image_inline_allowed_hosts", 
mode="before") @classmethod def _normalize_image_inline_allowed_hosts(cls, value: object) -> list[str]: diff --git a/app/core/errors.py b/app/core/errors.py index e31789bb..c9841dcc 100644 --- a/app/core/errors.py +++ b/app/core/errors.py @@ -18,6 +18,16 @@ class OpenAIErrorEnvelope(TypedDict): error: OpenAIErrorDetail +class AnthropicErrorDetail(TypedDict): + type: str + message: str + + +class AnthropicErrorEnvelope(TypedDict): + type: str + error: AnthropicErrorDetail + + class DashboardErrorDetail(TypedDict): code: str message: str @@ -45,6 +55,10 @@ def openai_error(code: str, message: str, error_type: str = "server_error") -> O return {"error": {"message": message, "type": error_type, "code": code}} +def anthropic_error(error_type: str, message: str) -> AnthropicErrorEnvelope: + return {"type": "error", "error": {"type": error_type, "message": message}} + + def dashboard_error(code: str, message: str) -> DashboardErrorEnvelope: return {"error": {"code": code, "message": message}} diff --git a/app/core/handlers/exceptions.py b/app/core/handlers/exceptions.py index df0e27bb..5923de5c 100644 --- a/app/core/handlers/exceptions.py +++ b/app/core/handlers/exceptions.py @@ -11,7 +11,7 @@ from fastapi.responses import JSONResponse, Response from starlette.exceptions import HTTPException as StarletteHTTPException -from app.core.errors import dashboard_error, openai_error +from app.core.errors import anthropic_error, dashboard_error, openai_error from app.core.exceptions import ( AppError, DashboardAuthError, @@ -109,6 +109,11 @@ async def validation_error_handler( if param: error["error"]["param"] = param return JSONResponse(status_code=400, content=error) + if fmt == "anthropic": + return JSONResponse( + status_code=400, + content=anthropic_error("invalid_request_error", "Invalid request payload"), + ) return await request_validation_exception_handler(request, exc) @app.exception_handler(StarletteHTTPException) @@ -142,6 +147,21 @@ async def http_error_handler( 
error_type = "server_error" code = "server_error" return JSONResponse(status_code=exc.status_code, content=openai_error(code, detail, error_type=error_type)) + if fmt == "anthropic": + if isinstance(exc.detail, dict): + return JSONResponse(status_code=exc.status_code, content=exc.detail) + error_type = "api_error" + if exc.status_code == 400: + error_type = "invalid_request_error" + elif exc.status_code == 401: + error_type = "authentication_error" + elif exc.status_code == 403: + error_type = "permission_error" + elif exc.status_code == 404: + error_type = "not_found_error" + elif exc.status_code == 429: + error_type = "rate_limit_error" + return JSONResponse(status_code=exc.status_code, content=anthropic_error(error_type, detail)) return await http_exception_handler(request, exc) # --- Catch-all for unhandled exceptions --- @@ -160,4 +180,9 @@ async def unhandled_error_handler(request: Request, exc: Exception) -> JSONRespo status_code=500, content=openai_error("server_error", "Internal server error", error_type="server_error"), ) + if fmt == "anthropic": + return JSONResponse( + status_code=500, + content=anthropic_error("api_error", "Internal server error"), + ) return JSONResponse(status_code=500, content={"detail": "Internal Server Error"}) diff --git a/app/core/openai/model_refresh_scheduler.py b/app/core/openai/model_refresh_scheduler.py index ce323aa4..8607ad96 100644 --- a/app/core/openai/model_refresh_scheduler.py +++ b/app/core/openai/model_refresh_scheduler.py @@ -93,6 +93,8 @@ def _group_by_plan(accounts: list[Account]) -> dict[str, list[Account]]: for account in accounts: if account.status != AccountStatus.ACTIVE: continue + if not account.chatgpt_account_id: + continue plan_type = account.plan_type if not plan_type: continue @@ -115,7 +117,7 @@ async def _fetch_with_failover( except ModelFetchError as exc: if exc.status_code == 401: try: - account = await auth_manager.ensure_fresh(account, force=True) + account = await 
AuthManager(accounts_repo).ensure_fresh(account, force=True) access_token = encryptor.decrypt(account.access_token_encrypted) return await fetch_models_for_plan(access_token, account.chatgpt_account_id) except (ModelFetchError, RefreshError): diff --git a/app/core/usage/pricing.py b/app/core/usage/pricing.py index 7ebca28c..947bd8f6 100644 --- a/app/core/usage/pricing.py +++ b/app/core/usage/pricing.py @@ -74,6 +74,17 @@ def _normalize_usage(usage: UsageTokens | ResponseUsage | None) -> UsageTokens | ), "gpt-5.1-codex": ModelPrice(input_per_1m=1.25, cached_input_per_1m=0.125, output_per_1m=10.0), "gpt-5-codex": ModelPrice(input_per_1m=1.25, cached_input_per_1m=0.125, output_per_1m=10.0), + "claude-opus-4.6": ModelPrice(input_per_1m=5.0, cached_input_per_1m=0.5, output_per_1m=25.0), + "claude-opus-4.5": ModelPrice(input_per_1m=5.0, cached_input_per_1m=0.5, output_per_1m=25.0), + "claude-opus-4.1": ModelPrice(input_per_1m=15.0, cached_input_per_1m=1.5, output_per_1m=75.0), + "claude-opus-4": ModelPrice(input_per_1m=15.0, cached_input_per_1m=1.5, output_per_1m=75.0), + "claude-sonnet-4.6": ModelPrice(input_per_1m=3.0, cached_input_per_1m=0.3, output_per_1m=15.0), + "claude-sonnet-4.5": ModelPrice(input_per_1m=3.0, cached_input_per_1m=0.3, output_per_1m=15.0), + "claude-sonnet-4": ModelPrice(input_per_1m=3.0, cached_input_per_1m=0.3, output_per_1m=15.0), + "claude-3-7-sonnet": ModelPrice(input_per_1m=3.0, cached_input_per_1m=0.3, output_per_1m=15.0), + "claude-haiku-4.5": ModelPrice(input_per_1m=1.0, cached_input_per_1m=0.1, output_per_1m=5.0), + "claude-3-5-haiku": ModelPrice(input_per_1m=0.8, cached_input_per_1m=0.08, output_per_1m=4.0), + "claude-haiku-3": ModelPrice(input_per_1m=0.25, cached_input_per_1m=0.03, output_per_1m=1.25), } DEFAULT_MODEL_ALIASES: dict[str, str] = { @@ -85,6 +96,54 @@ def _normalize_usage(usage: UsageTokens | ResponseUsage | None) -> UsageTokens | "gpt-5.1-codex-mini*": "gpt-5.1-codex-mini", "gpt-5.1-codex*": "gpt-5.1-codex", 
"gpt-5-codex*": "gpt-5-codex", + "claude-opus-4.6*": "claude-opus-4.6", + "claude-opus-4-6*": "claude-opus-4.6", + "claude-opus-4.5*": "claude-opus-4.5", + "claude-opus-4-5*": "claude-opus-4.5", + "claude-opus-4.1*": "claude-opus-4.1", + "claude-opus-4-1*": "claude-opus-4.1", + "claude-opus-4*": "claude-opus-4", + "claude-sonnet-4.6*": "claude-sonnet-4.6", + "claude-sonnet-4-6*": "claude-sonnet-4.6", + "claude-sonnet-4.5*": "claude-sonnet-4.5", + "claude-sonnet-4-5*": "claude-sonnet-4.5", + "claude-sonnet-4*": "claude-sonnet-4", + "claude-3-7-sonnet*": "claude-3-7-sonnet", + "claude-haiku-4.5*": "claude-haiku-4.5", + "claude-haiku-4-5*": "claude-haiku-4.5", + "claude-3-5-haiku*": "claude-3-5-haiku", + "anthropic/claude-opus-4.6*": "claude-opus-4.6", + "anthropic/claude-opus-4-6*": "claude-opus-4.6", + "anthropic/claude-opus-4.5*": "claude-opus-4.5", + "anthropic/claude-opus-4-5*": "claude-opus-4.5", + "anthropic/claude-opus-4.1*": "claude-opus-4.1", + "anthropic/claude-opus-4-1*": "claude-opus-4.1", + "anthropic/claude-opus-4*": "claude-opus-4", + "anthropic/claude-sonnet-4.6*": "claude-sonnet-4.6", + "anthropic/claude-sonnet-4-6*": "claude-sonnet-4.6", + "anthropic/claude-sonnet-4.5*": "claude-sonnet-4.5", + "anthropic/claude-sonnet-4-5*": "claude-sonnet-4.5", + "anthropic/claude-sonnet-4*": "claude-sonnet-4", + "anthropic/claude-3-7-sonnet*": "claude-3-7-sonnet", + "anthropic/claude-haiku-4.5*": "claude-haiku-4.5", + "anthropic/claude-haiku-4-5*": "claude-haiku-4.5", + "anthropic/claude-3-5-haiku*": "claude-3-5-haiku", + "anthropic2/claude-opus-4.6*": "claude-opus-4.6", + "anthropic2/claude-opus-4-6*": "claude-opus-4.6", + "anthropic2/claude-opus-4.5*": "claude-opus-4.5", + "anthropic2/claude-opus-4-5*": "claude-opus-4.5", + "anthropic2/claude-opus-4.1*": "claude-opus-4.1", + "anthropic2/claude-opus-4-1*": "claude-opus-4.1", + "anthropic2/claude-opus-4*": "claude-opus-4", + "anthropic2/claude-sonnet-4.6*": "claude-sonnet-4.6", + "anthropic2/claude-sonnet-4-6*": 
"claude-sonnet-4.6", + "anthropic2/claude-sonnet-4.5*": "claude-sonnet-4.5", + "anthropic2/claude-sonnet-4-5*": "claude-sonnet-4.5", + "anthropic2/claude-sonnet-4*": "claude-sonnet-4", + "anthropic2/claude-3-7-sonnet*": "claude-3-7-sonnet", + "anthropic2/claude-haiku-4.5*": "claude-haiku-4.5", + "anthropic2/claude-haiku-4-5*": "claude-haiku-4.5", + "anthropic2/claude-3-5-haiku*": "claude-3-5-haiku", } diff --git a/app/core/usage/refresh_scheduler.py b/app/core/usage/refresh_scheduler.py index 3673c0ec..86edd005 100644 --- a/app/core/usage/refresh_scheduler.py +++ b/app/core/usage/refresh_scheduler.py @@ -8,6 +8,8 @@ from app.core.config.settings import get_settings from app.db.session import get_background_session from app.modules.accounts.repository import AccountsRepository +from app.modules.anthropic.repository import AnthropicRepository +from app.modules.anthropic.service import AnthropicService from app.modules.proxy.rate_limit_cache import get_rate_limit_headers_cache from app.modules.usage.repository import UsageRepository from app.modules.usage.updater import UsageUpdater @@ -52,12 +54,19 @@ async def _refresh_once(self) -> None: async with self._lock: try: async with get_background_session() as session: - usage_repo = UsageRepository(session) - accounts_repo = AccountsRepository(session) - latest_usage = await usage_repo.latest_by_account(window="primary") - accounts = await accounts_repo.list_accounts() - updater = UsageUpdater(usage_repo, accounts_repo) - await updater.refresh_accounts(accounts, latest_usage) + settings = get_settings() + if settings.usage_refresh_enabled: + usage_repo = UsageRepository(session) + accounts_repo = AccountsRepository(session) + latest_usage = await usage_repo.latest_by_account(window="primary") + accounts = await accounts_repo.list_accounts() + updater = UsageUpdater(usage_repo, accounts_repo) + await updater.refresh_accounts(accounts, latest_usage) + + if settings.anthropic_usage_refresh_enabled: + service = 
AnthropicService(AnthropicRepository(session)) + await service.refresh_usage_windows() + await get_rate_limit_headers_cache().invalidate() except Exception: logger.exception("Usage refresh loop failed") @@ -65,7 +74,8 @@ async def _refresh_once(self) -> None: def build_usage_refresh_scheduler() -> UsageRefreshScheduler: settings = get_settings() + enabled = settings.usage_refresh_enabled or settings.anthropic_usage_refresh_enabled return UsageRefreshScheduler( interval_seconds=settings.usage_refresh_interval_seconds, - enabled=settings.usage_refresh_enabled, + enabled=enabled, ) diff --git a/app/dependencies.py b/app/dependencies.py index 3fb6adef..f414a9b1 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -10,6 +10,8 @@ from app.db.session import get_background_session, get_session from app.modules.accounts.repository import AccountsRepository from app.modules.accounts.service import AccountsService +from app.modules.anthropic.repository import AnthropicRepository +from app.modules.anthropic.service import AnthropicService from app.modules.api_keys.repository import ApiKeysRepository from app.modules.api_keys.service import ApiKeysService from app.modules.dashboard.repository import DashboardRepository @@ -59,6 +61,11 @@ class ProxyContext: service: ProxyService +@dataclass(slots=True) +class AnthropicContext: + service: AnthropicService + + @dataclass(slots=True) class ApiKeysContext: session: AsyncSession @@ -159,6 +166,14 @@ def get_proxy_context(request: Request) -> ProxyContext: return ProxyContext(service=service) +def get_anthropic_context( + session: AsyncSession = Depends(get_session), +) -> AnthropicContext: + repository = AnthropicRepository(session) + service = AnthropicService(repository) + return AnthropicContext(service=service) + + def get_api_keys_context( session: AsyncSession = Depends(get_session), ) -> ApiKeysContext: diff --git a/app/main.py b/app/main.py index 9a824f60..e643b234 100644 --- a/app/main.py +++ b/app/main.py @@ -6,6 
+6,7 @@ from fastapi import FastAPI, HTTPException from fastapi.responses import FileResponse +from app.core.clients.anthropic_proxy import close_anthropic_client_pools from app.core.clients.http import close_http_client, init_http_client from app.core.config.settings_cache import get_settings_cache from app.core.handlers import add_exception_handlers @@ -17,6 +18,7 @@ from app.core.usage.refresh_scheduler import build_usage_refresh_scheduler from app.db.session import close_db, init_db from app.modules.accounts import api as accounts_api +from app.modules.anthropic import api as anthropic_api from app.modules.api_keys import api as api_keys_api from app.modules.dashboard import api as dashboard_api from app.modules.dashboard_auth import api as dashboard_auth_api @@ -46,6 +48,7 @@ async def lifespan(_: FastAPI): await model_scheduler.stop() await usage_scheduler.stop() try: + await close_anthropic_client_pools() await close_http_client() finally: await close_db() @@ -65,6 +68,9 @@ def create_app() -> FastAPI: app.include_router(proxy_api.router) app.include_router(proxy_api.v1_router) + app.include_router(anthropic_api.router) + app.include_router(anthropic_api.api_router) + app.include_router(anthropic_api.diagnostics_router) app.include_router(proxy_api.usage_router) app.include_router(accounts_api.router) app.include_router(dashboard_api.router) diff --git a/app/modules/accounts/api.py b/app/modules/accounts/api.py index c75fdbcf..09460083 100644 --- a/app/modules/accounts/api.py +++ b/app/modules/accounts/api.py @@ -1,6 +1,6 @@ from __future__ import annotations -from fastapi import APIRouter, Depends, File, UploadFile +from fastapi import APIRouter, Depends, File, Form, UploadFile from app.core.auth.dependencies import set_dashboard_error_format, validate_dashboard_session from app.core.exceptions import DashboardBadRequestError, DashboardConflictError, DashboardNotFoundError @@ -14,7 +14,11 @@ AccountsResponse, AccountTrendsResponse, ) -from 
app.modules.accounts.service import InvalidAuthJsonError +from app.modules.accounts.service import ( + InvalidAnthropicAuthJsonError, + InvalidAnthropicEmailError, + InvalidAuthJsonError, +) router = APIRouter( prefix="/api/accounts", @@ -56,6 +60,24 @@ async def import_account( raise DashboardConflictError(str(exc), code="duplicate_identity_conflict") from exc +@router.post("/import-anthropic", response_model=AccountImportResponse) +async def import_anthropic_account( + credentials_json: UploadFile = File(...), + email: str = Form(...), + context: AccountsContext = Depends(get_accounts_context), +) -> AccountImportResponse: + raw = await credentials_json.read() + try: + return await context.service.import_anthropic_account(raw, email=email) + except InvalidAnthropicAuthJsonError as exc: + raise DashboardBadRequestError( + "Invalid Anthropic credentials payload", + code="invalid_anthropic_auth_json", + ) from exc + except InvalidAnthropicEmailError as exc: + raise DashboardBadRequestError(str(exc), code="invalid_anthropic_email") from exc + + @router.post("/{account_id}/reactivate", response_model=AccountReactivateResponse) async def reactivate_account( account_id: str, diff --git a/app/modules/accounts/service.py b/app/modules/accounts/service.py index 2aa25552..1125f457 100644 --- a/app/modules/accounts/service.py +++ b/app/modules/accounts/service.py @@ -12,6 +12,8 @@ generate_unique_account_id, parse_auth_json, ) +from app.core.auth.anthropic_credentials import parse_anthropic_auth_json +from app.core.config.settings import get_settings from app.core.crypto import TokenEncryptor from app.core.plan_types import coerce_account_plan_type from app.core.utils.time import naive_utc_to_epoch, to_utc_naive, utcnow @@ -34,6 +36,14 @@ class InvalidAuthJsonError(Exception): pass +class InvalidAnthropicAuthJsonError(Exception): + pass + + +class InvalidAnthropicEmailError(Exception): + pass + + class AccountsService: def __init__( self, @@ -117,6 +127,61 @@ async def 
import_account(self, raw: bytes) -> AccountImportResponse: status=saved.status, ) + async def import_anthropic_account(self, raw: bytes, *, email: str) -> AccountImportResponse: + try: + auth = parse_anthropic_auth_json(raw) + except (json.JSONDecodeError, TypeError, UnicodeDecodeError, ValueError) as exc: + raise InvalidAnthropicAuthJsonError("Invalid Anthropic credential payload") from exc + + normalized_email = email.strip().lower() + if "@" not in normalized_email: + raise InvalidAnthropicEmailError("Invalid Anthropic account email") + + settings = get_settings() + account_id = settings.anthropic_default_account_id + plan_type = coerce_account_plan_type(settings.anthropic_default_plan_type, DEFAULT_PLAN) + + refresh_token = auth.refresh_token or "" + encrypted_access = self._encryptor.encrypt(auth.access_token) + encrypted_refresh = self._encryptor.encrypt(refresh_token) + encrypted_id = self._encryptor.encrypt("") + + existing = await self._repo.get_by_id(account_id) + if existing is None: + account = Account( + id=account_id, + chatgpt_account_id=None, + email=normalized_email, + plan_type=plan_type, + access_token_encrypted=encrypted_access, + refresh_token_encrypted=encrypted_refresh, + id_token_encrypted=encrypted_id, + last_refresh=utcnow(), + status=AccountStatus.ACTIVE, + deactivation_reason=None, + ) + saved = await self._repo.upsert(account, merge_by_email=False) + else: + await self._repo.update_tokens( + account_id=account_id, + access_token_encrypted=encrypted_access, + refresh_token_encrypted=encrypted_refresh, + id_token_encrypted=encrypted_id, + last_refresh=utcnow(), + plan_type=plan_type, + email=normalized_email, + ) + await self._repo.update_status(account_id, AccountStatus.ACTIVE, None) + saved = await self._repo.get_by_id(account_id) + if saved is None: + raise RuntimeError("Failed to load saved Anthropic account") + return AccountImportResponse( + account_id=saved.id, + email=saved.email, + plan_type=saved.plan_type, + 
status=saved.status, + ) + async def reactivate_account(self, account_id: str) -> bool: return await self._repo.update_status(account_id, AccountStatus.ACTIVE, None) diff --git a/app/modules/anthropic/__init__.py b/app/modules/anthropic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/modules/anthropic/api.py b/app/modules/anthropic/api.py new file mode 100644 index 00000000..e7604003 --- /dev/null +++ b/app/modules/anthropic/api.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import Literal + +from fastapi import APIRouter, Depends, HTTPException, Query, Request, Security +from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from app.core.auth.dependencies import ( + set_anthropic_error_format, + set_dashboard_error_format, + validate_dashboard_session, +) +from app.core.clients.anthropic_api_proxy import get_recent_diagnostics +from app.core.clients.anthropic_proxy import AnthropicProxyError, anthropic_error_payload +from app.core.config.settings_cache import get_settings_cache +from app.core.types import JsonValue +from app.db.session import get_background_session +from app.dependencies import AnthropicContext, get_anthropic_context +from app.modules.api_keys.repository import ApiKeysRepository +from app.modules.api_keys.service import ( + ApiKeyData, + ApiKeyInvalidError, + ApiKeyRateLimitExceededError, + ApiKeysService, + ApiKeyUsageReservationData, +) + +router = APIRouter(prefix="/claude/v1", tags=["anthropic"], dependencies=[Depends(set_anthropic_error_format)]) +api_router = APIRouter(prefix="/claude-sdk/v1", tags=["anthropic"], dependencies=[Depends(set_anthropic_error_format)]) +diagnostics_router = APIRouter( + prefix="/api/anthropic", + tags=["dashboard"], + dependencies=[Depends(validate_dashboard_session), Depends(set_dashboard_error_format)], +) + +_bearer = 
HTTPBearer(description="API key (e.g. sk-clb-...)", auto_error=False) + + +async def validate_anthropic_api_key( + credentials: HTTPAuthorizationCredentials | None = Security(_bearer), +) -> ApiKeyData | None: + settings = await get_settings_cache().get() + if not settings.api_key_auth_enabled: + return None + + if credentials is None: + raise HTTPException( + status_code=401, + detail=anthropic_error_payload("authentication_error", "Missing API key in Authorization header"), + ) + + token = credentials.credentials + async with get_background_session() as session: + service = ApiKeysService(ApiKeysRepository(session)) + try: + return await service.validate_key(token) + except ApiKeyInvalidError as exc: + raise HTTPException( + status_code=401, + detail=anthropic_error_payload("authentication_error", str(exc)), + ) from exc + + +@router.post("/messages") +async def messages( + request: Request, + context: AnthropicContext = Depends(get_anthropic_context), + api_key: ApiKeyData | None = Security(validate_anthropic_api_key), +): + return await _messages_impl(request, context, api_key, transport="api") + + +@api_router.post("/messages") +async def messages_api( + request: Request, + context: AnthropicContext = Depends(get_anthropic_context), + api_key: ApiKeyData | None = Security(validate_anthropic_api_key), +): + return await _messages_impl(request, context, api_key, transport="sdk") + + +async def _messages_impl( + request: Request, + context: AnthropicContext, + api_key: ApiKeyData | None, + *, + transport: Literal["sdk", "api"], +): + payload = await _require_json_object(request) + model = _extract_model(payload) + _validate_model_access(api_key, model) + reservation = await _enforce_request_limits(api_key, request_model=model) + stream = bool(payload.get("stream")) + + if stream: + upstream_stream = context.service.stream_messages( + payload, + request.headers, + api_key=api_key, + api_key_reservation=reservation, + transport=transport, + ) + try: + first = await 
upstream_stream.__anext__() + except StopAsyncIteration: + return StreamingResponse( + _prepend_first(None, upstream_stream), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache"}, + ) + except AnthropicProxyError as exc: + return JSONResponse(status_code=exc.status_code, content=exc.payload) + + return StreamingResponse( + _prepend_first(first, upstream_stream), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache"}, + ) + + try: + response_payload = await context.service.create_message( + payload, + request.headers, + api_key=api_key, + api_key_reservation=reservation, + transport=transport, + ) + return JSONResponse(content=response_payload) + except AnthropicProxyError as exc: + return JSONResponse(status_code=exc.status_code, content=exc.payload) + + +async def _enforce_request_limits( + api_key: ApiKeyData | None, + *, + request_model: str | None, +) -> ApiKeyUsageReservationData | None: + if api_key is None: + return None + + async with get_background_session() as session: + service = ApiKeysService(ApiKeysRepository(session)) + try: + return await service.enforce_limits_for_request(api_key.id, request_model=request_model) + except ApiKeyRateLimitExceededError as exc: + message = f"{exc}. Usage resets at {exc.reset_at.isoformat()}Z." 
+ raise HTTPException( + status_code=429, + detail=anthropic_error_payload("rate_limit_error", message), + ) from exc + except ApiKeyInvalidError as exc: + raise HTTPException( + status_code=401, + detail=anthropic_error_payload("authentication_error", str(exc)), + ) from exc + + +def _validate_model_access(api_key: ApiKeyData | None, model: str | None) -> None: + if api_key is None: + return + allowed_models = api_key.allowed_models + if not allowed_models: + return + if model is None or model in allowed_models: + return + message = f"This API key does not have access to model '{model}'" + raise HTTPException( + status_code=403, + detail=anthropic_error_payload("permission_error", message), + ) + + +async def _prepend_first(first: str | None, stream: AsyncIterator[str]) -> AsyncIterator[str]: + if first is not None: + yield first + async for line in stream: + yield line + + +async def _require_json_object(request: Request) -> dict[str, JsonValue]: + try: + payload = await request.json() + except Exception as exc: + raise HTTPException( + status_code=400, + detail=anthropic_error_payload("invalid_request_error", "Invalid JSON body"), + ) from exc + if not isinstance(payload, dict): + raise HTTPException( + status_code=400, + detail=anthropic_error_payload("invalid_request_error", "Request body must be an object"), + ) + return payload + + +def _extract_model(payload: dict[str, JsonValue]) -> str | None: + model = payload.get("model") + if isinstance(model, str) and model.strip(): + return model.strip() + return None + + +@diagnostics_router.get("/diagnostics") +async def list_anthropic_diagnostics( + limit: int = Query(100, ge=1, le=500), +) -> dict[str, list[dict[str, object]]]: + return {"entries": get_recent_diagnostics(limit)} diff --git a/app/modules/anthropic/repository.py b/app/modules/anthropic/repository.py new file mode 100644 index 00000000..af78a4b1 --- /dev/null +++ b/app/modules/anthropic/repository.py @@ -0,0 +1,102 @@ +from __future__ import 
annotations + +from datetime import datetime + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.crypto import TokenEncryptor +from app.core.utils.time import utcnow +from app.db.models import Account, AccountStatus, RequestLog, UsageHistory +from app.modules.accounts.repository import AccountsRepository +from app.modules.request_logs.repository import RequestLogsRepository +from app.modules.usage.repository import UsageRepository + +_PLACEHOLDER_TOKEN = "anthropic-provider" + + +class AnthropicRepository: + def __init__(self, session: AsyncSession) -> None: + self._accounts = AccountsRepository(session) + self._request_logs = RequestLogsRepository(session) + self._usage = UsageRepository(session) + self._encryptor = TokenEncryptor() + + async def ensure_provider_account( + self, + *, + account_id: str, + email: str, + plan_type: str, + ) -> Account: + existing = await self._accounts.get_by_id(account_id) + if existing is not None: + if existing.status != AccountStatus.ACTIVE: + await self._accounts.update_status(account_id, AccountStatus.ACTIVE, None) + existing.status = AccountStatus.ACTIVE + existing.deactivation_reason = None + return existing + + now = utcnow() + encrypted = self._encryptor.encrypt(_PLACEHOLDER_TOKEN) + created = Account( + id=account_id, + chatgpt_account_id=None, + email=email, + plan_type=plan_type, + access_token_encrypted=encrypted, + refresh_token_encrypted=encrypted, + id_token_encrypted=encrypted, + last_refresh=now, + status=AccountStatus.ACTIVE, + deactivation_reason=None, + ) + return await self._accounts.upsert(created, merge_by_email=False) + + async def add_request_log( + self, + *, + account_id: str, + request_id: str, + model: str, + input_tokens: int | None, + output_tokens: int | None, + cached_input_tokens: int | None, + latency_ms: int | None, + status: str, + error_code: str | None, + error_message: str | None, + api_key_id: str | None, + requested_at: datetime | None = None, + ) -> RequestLog: + return 
await self._request_logs.add_log( + account_id=account_id, + request_id=request_id, + model=model, + input_tokens=input_tokens, + output_tokens=output_tokens, + cached_input_tokens=cached_input_tokens, + latency_ms=latency_ms, + status=status, + error_code=error_code, + error_message=error_message, + api_key_id=api_key_id, + requested_at=requested_at, + ) + + async def add_usage_entry( + self, + *, + account_id: str, + used_percent: float, + window: str, + reset_at: int | None, + window_minutes: int, + ) -> UsageHistory: + entry = await self._usage.add_entry( + account_id=account_id, + used_percent=used_percent, + window=window, + reset_at=reset_at, + window_minutes=window_minutes, + ) + return entry diff --git a/app/modules/anthropic/schemas.py b/app/modules/anthropic/schemas.py new file mode 100644 index 00000000..36061993 --- /dev/null +++ b/app/modules/anthropic/schemas.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict + + +class AnthropicErrorDetail(BaseModel): + model_config = ConfigDict(extra="allow") + + type: str + message: str + + +class AnthropicErrorEnvelope(BaseModel): + model_config = ConfigDict(extra="allow") + + type: str = "error" + error: AnthropicErrorDetail diff --git a/app/modules/anthropic/service.py b/app/modules/anthropic/service.py new file mode 100644 index 00000000..e3aeba33 --- /dev/null +++ b/app/modules/anthropic/service.py @@ -0,0 +1,507 @@ +from __future__ import annotations + +import logging +import time +from collections.abc import AsyncIterator, Mapping +from dataclasses import dataclass +from typing import Literal + +import anyio + +from app.core.auth.anthropic_credentials import credentials_from_account as anthropic_credentials_from_account +from app.core.auth.anthropic_credentials import resolve_anthropic_credentials +from app.core.clients.anthropic_api_proxy import ( + create_message as core_create_message_api, +) +from app.core.clients.anthropic_api_proxy import ( + 
stream_messages as core_stream_messages_api, +) +from app.core.clients.anthropic_proxy import ( + AnthropicProxyError, + parse_sse_data_payload, +) +from app.core.clients.anthropic_proxy import ( + create_message as core_create_message, +) +from app.core.clients.anthropic_proxy import ( + stream_messages as core_stream_messages, +) +from app.core.clients.anthropic_usage import AnthropicUsageFetchError, fetch_usage_snapshot +from app.core.config.settings import get_settings +from app.core.types import JsonValue +from app.core.utils.request_id import ensure_request_id, get_request_id +from app.db.models import Account +from app.modules.anthropic.repository import AnthropicRepository +from app.modules.api_keys.repository import ApiKeysRepository +from app.modules.api_keys.service import ApiKeyData, ApiKeysService, ApiKeyUsageReservationData + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True, slots=True) +class AnthropicRequestUsage: + input_tokens: int | None + output_tokens: int | None + cached_input_tokens: int | None + + +@dataclass(frozen=True, slots=True) +class AnthropicRequestError: + code: str | None + message: str | None + + +class AnthropicService: + def __init__(self, repository: AnthropicRepository) -> None: + self._repository = repository + + async def create_message( + self, + payload: dict[str, JsonValue], + headers: Mapping[str, str], + *, + api_key: ApiKeyData | None, + api_key_reservation: ApiKeyUsageReservationData | None, + transport: Literal["sdk", "api"] = "sdk", + ) -> dict[str, JsonValue]: + settings = get_settings() + request_id = ensure_request_id(headers.get("x-request-id") or headers.get("request-id")) + account = await self._repository.ensure_provider_account( + account_id=settings.anthropic_default_account_id, + email=settings.anthropic_default_account_email, + plan_type=settings.anthropic_default_plan_type, + ) + model = _extract_request_model(payload) + start = time.monotonic() + status = "success" + usage = 
AnthropicRequestUsage(input_tokens=None, output_tokens=None, cached_input_tokens=None) + error = AnthropicRequestError(code=None, message=None) + + try: + response_payload = await _create_message_with_transport( + transport, + payload, + headers, + account=account, + ) + model = _extract_response_model(response_payload) or model + usage = _usage_from_message_payload(response_payload) + return response_payload + except AnthropicProxyError as exc: + status = "error" + error = _extract_error(exc.payload) + raise + finally: + latency_ms = int((time.monotonic() - start) * 1000) + await self._persist_request_log( + account_id=account.id, + api_key_id=api_key.id if api_key else None, + request_id=request_id, + model=model, + usage=usage, + latency_ms=latency_ms, + status=status, + error=error, + ) + await self._settle_reservation( + api_key=api_key, + reservation=api_key_reservation, + model=model, + status=status, + usage=usage, + ) + + def stream_messages( + self, + payload: dict[str, JsonValue], + headers: Mapping[str, str], + *, + api_key: ApiKeyData | None, + api_key_reservation: ApiKeyUsageReservationData | None, + transport: Literal["sdk", "api"] = "sdk", + ) -> AsyncIterator[str]: + return self._stream_messages( + payload, + headers, + api_key=api_key, + api_key_reservation=api_key_reservation, + transport=transport, + ) + + async def refresh_usage_windows(self) -> bool: + settings = get_settings() + if not settings.anthropic_usage_refresh_enabled: + return False + + credentials = await resolve_anthropic_credentials() + if credentials is None: + return False + + account = await self._repository.ensure_provider_account( + account_id=settings.anthropic_default_account_id, + email=settings.anthropic_default_account_email, + plan_type=settings.anthropic_default_plan_type, + ) + + try: + snapshot = await fetch_usage_snapshot( + bearer_token=credentials.bearer_token, + ) + except AnthropicUsageFetchError as exc: + logger.warning( + "anthropic_usage_refresh_failed 
status=%s message=%s request_id=%s", + exc.status_code, + exc.message, + get_request_id(), + ) + return False + + wrote = False + if snapshot.five_hour is not None: + await self._repository.add_usage_entry( + account_id=account.id, + used_percent=snapshot.five_hour.used_percent, + window="primary", + reset_at=snapshot.five_hour.reset_at_epoch, + window_minutes=snapshot.five_hour.window_minutes, + ) + wrote = True + + if snapshot.seven_day is not None: + await self._repository.add_usage_entry( + account_id=account.id, + used_percent=snapshot.seven_day.used_percent, + window="secondary", + reset_at=snapshot.seven_day.reset_at_epoch, + window_minutes=snapshot.seven_day.window_minutes, + ) + wrote = True + + return wrote + + async def _stream_messages( + self, + payload: dict[str, JsonValue], + headers: Mapping[str, str], + *, + api_key: ApiKeyData | None, + api_key_reservation: ApiKeyUsageReservationData | None, + transport: Literal["sdk", "api"], + ) -> AsyncIterator[str]: + settings = get_settings() + request_id = ensure_request_id(headers.get("x-request-id") or headers.get("request-id")) + account = await self._repository.ensure_provider_account( + account_id=settings.anthropic_default_account_id, + email=settings.anthropic_default_account_email, + plan_type=settings.anthropic_default_plan_type, + ) + model = _extract_request_model(payload) + start = time.monotonic() + accumulator = _StreamAccumulator(model=model) + + try: + async for line in _stream_messages_with_transport( + transport, + payload, + headers, + account=account, + ): + event_payload = parse_sse_data_payload(line) + accumulator.observe(event_payload) + yield line + accumulator.mark_stream_end() + except AnthropicProxyError as exc: + accumulator.observe(exc.payload) + accumulator.mark_error_from_payload(exc.payload) + raise + finally: + usage = accumulator.to_usage() + error = accumulator.to_error() + status = accumulator.status + latency_ms = int((time.monotonic() - start) * 1000) + await 
self._persist_request_log( + account_id=account.id, + api_key_id=api_key.id if api_key else None, + request_id=request_id, + model=accumulator.model, + usage=usage, + latency_ms=latency_ms, + status=status, + error=error, + ) + await self._settle_reservation( + api_key=api_key, + reservation=api_key_reservation, + model=accumulator.model, + status=status, + usage=usage, + ) + + async def _persist_request_log( + self, + *, + account_id: str, + api_key_id: str | None, + request_id: str, + model: str, + usage: AnthropicRequestUsage, + latency_ms: int, + status: str, + error: AnthropicRequestError, + ) -> None: + with anyio.CancelScope(shield=True): + try: + await self._repository.add_request_log( + account_id=account_id, + api_key_id=api_key_id, + request_id=request_id, + model=model, + input_tokens=usage.input_tokens, + output_tokens=usage.output_tokens, + cached_input_tokens=usage.cached_input_tokens, + latency_ms=latency_ms, + status=status, + error_code=error.code, + error_message=error.message, + ) + except Exception: + logger.warning( + "anthropic_request_log_persist_failed request_id=%s account_id=%s", + request_id, + account_id, + exc_info=True, + ) + + async def _settle_reservation( + self, + *, + api_key: ApiKeyData | None, + reservation: ApiKeyUsageReservationData | None, + model: str, + status: str, + usage: AnthropicRequestUsage, + ) -> None: + if api_key is None or reservation is None: + return + + with anyio.CancelScope(shield=True): + try: + from app.db.session import get_background_session + + async with get_background_session() as session: + api_keys_service = ApiKeysService(ApiKeysRepository(session)) + if status == "success" and usage.input_tokens is not None and usage.output_tokens is not None: + await api_keys_service.finalize_usage_reservation( + reservation.reservation_id, + model=model, + input_tokens=usage.input_tokens, + output_tokens=usage.output_tokens, + cached_input_tokens=usage.cached_input_tokens or 0, + ) + else: + await 
api_keys_service.release_usage_reservation(reservation.reservation_id) + except Exception: + logger.warning( + "anthropic_reservation_settlement_failed request_id=%s key_id=%s", + get_request_id(), + api_key.id, + exc_info=True, + ) + + +@dataclass(slots=True) +class _StreamAccumulator: + model: str + status: str = "success" + error_code: str | None = None + error_message: str | None = None + input_tokens: int | None = None + output_tokens: int | None = None + cached_input_tokens: int | None = None + saw_terminal: bool = False + + def observe(self, payload: dict[str, JsonValue] | None) -> None: + if payload is None: + return + payload_type = payload.get("type") + if not isinstance(payload_type, str): + return + + if payload_type == "message_start": + message = payload.get("message") + if isinstance(message, dict): + model_value = message.get("model") + if isinstance(model_value, str) and model_value.strip(): + self.model = model_value.strip() + usage = message.get("usage") + self._apply_usage(usage if isinstance(usage, dict) else None) + return + + if payload_type == "message_delta": + usage = payload.get("usage") + if isinstance(usage, dict): + self._apply_usage(usage) + return + + if payload_type == "message_stop": + self.saw_terminal = True + return + + if payload_type == "error": + self.mark_error_from_payload(payload) + + def mark_error_from_payload(self, payload: dict[str, JsonValue]) -> None: + self.status = "error" + error = _extract_error(payload) + self.error_code = error.code + self.error_message = error.message + + def mark_stream_end(self) -> None: + if self.status == "success" and not self.saw_terminal: + self.status = "error" + self.error_code = "stream_incomplete" + self.error_message = "Upstream closed stream without message_stop" + + def to_usage(self) -> AnthropicRequestUsage: + return AnthropicRequestUsage( + input_tokens=self.input_tokens, + output_tokens=self.output_tokens, + cached_input_tokens=self.cached_input_tokens, + ) + + def 
to_error(self) -> AnthropicRequestError: + return AnthropicRequestError(code=self.error_code, message=self.error_message) + + def _apply_usage(self, usage: Mapping[str, object] | None) -> None: + if usage is None: + return + input_tokens = _as_int(usage.get("input_tokens")) + cache_creation_input_tokens = _as_int(usage.get("cache_creation_input_tokens")) + output_tokens = _as_int(usage.get("output_tokens")) + cached_input_tokens = _as_int(usage.get("cache_read_input_tokens")) + total_input_tokens = _total_input_tokens_for_log( + input_tokens, + cache_creation_input_tokens, + cached_input_tokens, + ) + + if total_input_tokens is not None: + self.input_tokens = total_input_tokens + if output_tokens is not None: + self.output_tokens = output_tokens + if cached_input_tokens is not None: + self.cached_input_tokens = cached_input_tokens + + +def _extract_request_model(payload: dict[str, JsonValue]) -> str: + value = payload.get("model") + if isinstance(value, str) and value.strip(): + return value.strip() + return "anthropic-unknown" + + +def _extract_response_model(payload: dict[str, JsonValue]) -> str | None: + value = payload.get("model") + if isinstance(value, str) and value.strip(): + return value.strip() + return None + + +def _usage_from_message_payload(payload: dict[str, JsonValue]) -> AnthropicRequestUsage: + usage = payload.get("usage") + if not isinstance(usage, dict): + return AnthropicRequestUsage(input_tokens=None, output_tokens=None, cached_input_tokens=None) + input_tokens = _as_int(usage.get("input_tokens")) + cache_creation_input_tokens = _as_int(usage.get("cache_creation_input_tokens")) + cached_input_tokens = _as_int(usage.get("cache_read_input_tokens")) + return AnthropicRequestUsage( + input_tokens=_total_input_tokens_for_log( + input_tokens, + cache_creation_input_tokens, + cached_input_tokens, + ), + output_tokens=_as_int(usage.get("output_tokens")), + cached_input_tokens=cached_input_tokens, + ) + + +def _extract_error(payload: dict[str, 
JsonValue]) -> AnthropicRequestError: + error_value = payload.get("error") + if isinstance(error_value, dict): + raw_code = error_value.get("code") + raw_type = error_value.get("type") + raw_message = error_value.get("message") + return AnthropicRequestError( + code=_normalize_error_code(raw_code, raw_type), + message=raw_message if isinstance(raw_message, str) else None, + ) + + raw_type = payload.get("type") + return AnthropicRequestError( + code=_normalize_error_code(None, raw_type), + message=None, + ) + + +def _normalize_error_code(raw_code: JsonValue, raw_type: JsonValue) -> str | None: + if isinstance(raw_code, str) and raw_code.strip(): + return raw_code.strip().lower() + if isinstance(raw_type, str) and raw_type.strip(): + normalized_type = raw_type.strip().lower() + if normalized_type in {"rate_limit_error", "overloaded_error"}: + return "rate_limit_exceeded" + if normalized_type in {"insufficient_quota", "quota_exceeded", "usage_not_included"}: + return normalized_type + if normalized_type == "authentication_error": + return "invalid_api_key" + return normalized_type + return None + + +def _as_int(value: object) -> int | None: + if isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, float): + return int(value) + return None + + +def _total_input_tokens_for_log( + input_tokens: int | None, + cache_creation_input_tokens: int | None, + cache_read_input_tokens: int | None, +) -> int | None: + if input_tokens is None and cache_creation_input_tokens is None and cache_read_input_tokens is None: + return None + return (input_tokens or 0) + (cache_creation_input_tokens or 0) + (cache_read_input_tokens or 0) + + +async def _create_message_with_transport( + transport: Literal["sdk", "api"], + payload: dict[str, JsonValue], + headers: Mapping[str, str], + *, + account: Account, +) -> dict[str, JsonValue]: + if transport == "api": + credentials = anthropic_credentials_from_account(account) + return await 
core_create_message_api(payload, headers, credentials=credentials) + return await core_create_message(payload, headers) + + +async def _stream_messages_with_transport( + transport: Literal["sdk", "api"], + payload: dict[str, JsonValue], + headers: Mapping[str, str], + *, + account: Account, +) -> AsyncIterator[str]: + if transport == "api": + credentials = anthropic_credentials_from_account(account) + async for line in core_stream_messages_api(payload, headers, credentials=credentials): + yield line + return + async for line in core_stream_messages(payload, headers): + yield line diff --git a/app/modules/oauth/service.py b/app/modules/oauth/service.py index 7d59265a..59069183 100644 --- a/app/modules/oauth/service.py +++ b/app/modules/oauth/service.py @@ -26,7 +26,7 @@ generate_pkce_pair, request_device_code, ) -from app.core.config.settings import get_settings +from app.core.config.settings import Settings, get_settings from app.core.crypto import TokenEncryptor from app.core.plan_types import coerce_account_plan_type from app.core.utils.time import utcnow @@ -42,6 +42,7 @@ _async_sleep = asyncio.sleep _SUCCESS_TEMPLATE = Path(__file__).resolve().parent / "templates" / "oauth_success.html" +_ANTHROPIC_PROVIDER_PLACEHOLDER_TOKEN = "anthropic-provider" @dataclass @@ -132,7 +133,11 @@ async def start_oauth(self, request: OauthStartRequest) -> OauthStartResponse: force_method = (request.force_method or "").lower() if not force_method: accounts = await self._accounts_repo.list_accounts() - if accounts: + settings = get_settings() + has_existing_oauth_account = any( + not self._is_provider_seed_account(account, settings) for account in accounts + ) + if has_existing_oauth_account: async with self._store.lock: await self._store._cleanup_locked() self._store._state = OAuthState(status="success") @@ -342,6 +347,24 @@ async def _stop_callback_server(self) -> None: if server: await server.stop() + def _is_provider_seed_account(self, account: Account, settings: Settings) -> 
bool: + if account.id != settings.anthropic_default_account_id: + return False + if account.email != settings.anthropic_default_account_email: + return False + if account.chatgpt_account_id is not None: + return False + if not ( + account.access_token_encrypted == account.refresh_token_encrypted + and account.refresh_token_encrypted == account.id_token_encrypted + ): + return False + try: + decrypted = self._encryptor.decrypt(account.access_token_encrypted) + except Exception: + return False + return decrypted == _ANTHROPIC_PROVIDER_PLACEHOLDER_TOKEN + @staticmethod def _html_response(html: str) -> web.Response: return web.Response(text=html, content_type="text/html") diff --git a/app/modules/proxy/helpers.py b/app/modules/proxy/helpers.py index a32f844e..1d207d7b 100644 --- a/app/modules/proxy/helpers.py +++ b/app/modules/proxy/helpers.py @@ -43,7 +43,11 @@ def _header_account_id(account_id: str | None) -> str | None: def _select_accounts_for_limits(accounts: Iterable[Account]) -> list[Account]: - return [account for account in accounts if account.status not in (AccountStatus.DEACTIVATED, AccountStatus.PAUSED)] + return [ + account + for account in accounts + if account.status not in (AccountStatus.DEACTIVATED, AccountStatus.PAUSED) and account.chatgpt_account_id + ] def _summarize_window( diff --git a/app/modules/proxy/load_balancer.py b/app/modules/proxy/load_balancer.py index 7ed83afa..fd8076c6 100644 --- a/app/modules/proxy/load_balancer.py +++ b/app/modules/proxy/load_balancer.py @@ -343,6 +343,9 @@ def _state_from_account( def _filter_accounts_for_model(accounts: list[Account], model: str) -> list[Account]: + accounts = [account for account in accounts if account.chatgpt_account_id] + if not accounts: + return [] allowed_plans = get_model_registry().plan_types_for_model(model) if allowed_plans is None: return accounts diff --git a/app/modules/usage/updater.py b/app/modules/usage/updater.py index 76718960..179dec5f 100644 --- a/app/modules/usage/updater.py 
+++ b/app/modules/usage/updater.py @@ -66,6 +66,8 @@ async def refresh_accounts( now = utcnow() interval = settings.usage_refresh_interval_seconds for account in accounts: + if _is_anthropic_provider_account(account, settings): + continue if account.status == AccountStatus.DEACTIVATED: continue latest = latest_usage.get(account.id) @@ -256,3 +258,7 @@ def _reset_at(reset_at: int | None, reset_after_seconds: int | None, now_epoch: def _should_deactivate_for_usage_error(status_code: int) -> bool: return status_code in _DEACTIVATING_USAGE_STATUS_CODES + + +def _is_anthropic_provider_account(account: Account, settings) -> bool: + return account.id == settings.anthropic_default_account_id or account.id.startswith("anthropic_") diff --git a/frontend/src/features/accounts/api.ts b/frontend/src/features/accounts/api.ts index bb964553..03a5ceca 100644 --- a/frontend/src/features/accounts/api.ts +++ b/frontend/src/features/accounts/api.ts @@ -27,6 +27,15 @@ export function importAccount(file: File) { }); } +export function importAnthropicAccount({ file, email }: { file: File; email: string }) { + const formData = new FormData(); + formData.append("credentials_json", file); + formData.append("email", email); + return post(`${ACCOUNTS_BASE_PATH}/import-anthropic`, AccountImportResponseSchema, { + body: formData, + }); +} + export function pauseAccount(accountId: string) { return post( `${ACCOUNTS_BASE_PATH}/${encodeURIComponent(accountId)}/pause`, diff --git a/frontend/src/features/accounts/components/account-detail.tsx b/frontend/src/features/accounts/components/account-detail.tsx index 4c700532..0b34a89d 100644 --- a/frontend/src/features/accounts/components/account-detail.tsx +++ b/frontend/src/features/accounts/components/account-detail.tsx @@ -1,11 +1,12 @@ -import { User } from "lucide-react"; +import { Bot, SquareTerminal, User } from "lucide-react"; +import { cn } from "@/lib/utils"; import { AccountActions } from "@/features/accounts/components/account-actions"; 
import { AccountTokenInfo } from "@/features/accounts/components/account-token-info"; import { AccountUsagePanel } from "@/features/accounts/components/account-usage-panel"; import type { AccountSummary } from "@/features/accounts/schemas"; import { useAccountTrends } from "@/features/accounts/hooks/use-accounts"; -import { formatCompactAccountId } from "@/utils/account-identifiers"; +import { isAnthropicAccountId, providerLabelForAccountId } from "@/utils/account-provider"; export type AccountDetailProps = { account: AccountSummary | null; @@ -26,6 +27,7 @@ export function AccountDetail({ onDelete, onReauth, }: AccountDetailProps) { + void showAccountId; const { data: trends } = useAccountTrends(account?.accountId ?? null); if (!account) { @@ -40,23 +42,41 @@ export function AccountDetail({ ); } + const isAnthropic = isAnthropicAccountId(account.accountId); + const ProviderIcon = isAnthropic ? Bot : SquareTerminal; + const providerLabel = providerLabelForAccountId(account.accountId); const title = account.displayName || account.email; - const compactId = formatCompactAccountId(account.accountId); const emailSubtitle = account.displayName && account.displayName !== account.email ? account.email : null; - const heading = showAccountId && !emailSubtitle ? `${title} (${compactId})` : title; - const subtitle = showAccountId && emailSubtitle ? `${emailSubtitle} | ID ${compactId}` : emailSubtitle; + const subtitle = emailSubtitle ? `${emailSubtitle} | ${providerLabel}` : providerLabel; return ( -
+
{subtitle}
) : null} diff --git a/frontend/src/features/accounts/components/account-list-item.tsx b/frontend/src/features/accounts/components/account-list-item.tsx index 1a5d9e9a..b6005d14 100644 --- a/frontend/src/features/accounts/components/account-list-item.tsx +++ b/frontend/src/features/accounts/components/account-list-item.tsx @@ -1,8 +1,10 @@ +import { Bot, SquareTerminal } from "lucide-react"; + import { cn } from "@/lib/utils"; import { StatusBadge } from "@/components/status-badge"; import type { AccountSummary } from "@/features/accounts/schemas"; import { normalizeStatus, quotaBarColor, quotaBarTrack } from "@/utils/account-status"; -import { formatCompactAccountId } from "@/utils/account-identifiers"; +import { isAnthropicAccountId, providerLabelForAccountId } from "@/utils/account-provider"; import { formatSlug } from "@/utils/formatters"; export type AccountListItemProps = { @@ -30,12 +32,15 @@ function MiniQuotaBar({ percent }: { percent: number | null }) { export function AccountListItem({ account, selected, showAccountId = false, onSelect }: AccountListItemProps) { const status = normalizeStatus(account.status); + const isAnthropic = isAnthropicAccountId(account.accountId); + const ProviderIcon = isAnthropic ? Bot : SquareTerminal; + const providerLabel = providerLabelForAccountId(account.accountId); const title = account.displayName || account.email; const baseSubtitle = account.displayName && account.displayName !== account.email ? account.email : formatSlug(account.planType); const subtitle = showAccountId - ? `${baseSubtitle} | ID ${formatCompactAccountId(account.accountId)}` + ? `${baseSubtitle} | ${providerLabel}` : baseSubtitle; const secondary = account.usage?.secondaryRemainingPercent ?? null; @@ -44,16 +49,30 @@ export function AccountListItem({ account, selected, showAccountId = false, onSe type="button" onClick={() => onSelect(account.accountId)} className={cn( - "w-full rounded-lg px-3 py-2.5 text-left transition-colors", - selected - ? 
"bg-primary/8 ring-1 ring-primary/25" - : "hover:bg-muted/50", + "w-full rounded-lg border px-3 py-2.5 text-left transition-colors", + isAnthropic && !selected && "border-amber-500/20 bg-amber-500/8 hover:bg-amber-500/14", + selected && isAnthropic && "border-amber-500/35 bg-amber-500/18 ring-1 ring-amber-500/25", + selected && !isAnthropic && "border-primary/20 bg-primary/8 ring-1 ring-primary/25", + !isAnthropic && !selected && "border-transparent hover:bg-muted/50", )} >{title}
-+
{title}
+{subtitle}