Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ CODEX_LB_OAUTH_CALLBACK_HOST=127.0.0.1
CODEX_LB_OAUTH_CALLBACK_PORT=1455
CODEX_LB_TOKEN_REFRESH_TIMEOUT_SECONDS=30
CODEX_LB_TOKEN_REFRESH_INTERVAL_DAYS=8
# Optional direct refresh endpoint override (used by refresh token exchange)
# CODEX_REFRESH_TOKEN_URL_OVERRIDE=https://auth.openai.com/oauth/token

# Encryption key file (optional override; recommended for Docker volumes)
# CODEX_LB_ENCRYPTION_KEY_FILE=/var/lib/codex-lb/encryption.key
Expand All @@ -37,6 +39,13 @@ CODEX_LB_USAGE_REFRESH_INTERVAL_SECONDS=60
CODEX_LB_STICKY_SESSION_CLEANUP_ENABLED=true
CODEX_LB_STICKY_SESSION_CLEANUP_INTERVAL_SECONDS=300

# Optional outbound HTTP proxy for upstream/OAuth/model requests
# CODEX_LB_HTTP_PROXY_URL=http://127.0.0.1:8080

# Optional additional proxy request guard (in addition to Authorization auth, if enabled)
# CODEX_LB_PROXY_KEY_AUTH_ENABLED=false
# CODEX_LB_PROXY_KEY=replace-with-strong-random-shared-key

# Firewall
# Trust X-Forwarded-For for firewall client IP detection (enable only behind trusted reverse proxy)
CODEX_LB_FIREWALL_TRUST_PROXY_HEADERS=false
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ __pycache__/
.env.*
!.env.example
.python-version
*.iml

# Build artifacts
build/
Expand Down
5 changes: 5 additions & 0 deletions Jenkinsfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
@Library(['dada-tuda-jenkins-pipelines@develop', 'maven-lib@1.0.10']) _
pythonPipeline(
projectName: "codex-proxy",
python: "3.14"
)
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ Load balancer for ChatGPT accounts. Pool multiple accounts, track usage, manage
</tr>
</table>

- Accounts UI supports batch `auth.json` import and one-click auth archive export.

## Quick Start

```bash
Expand Down Expand Up @@ -283,13 +285,16 @@ When enabled, clients must pass a valid API key as a Bearer token:
Authorization: Bearer sk-clb-...
```

Optional extra hardening: enable `CODEX_LB_PROXY_KEY_AUTH_ENABLED=true` with `CODEX_LB_PROXY_KEY` to require `X-Codex-Proxy-Key` on proxy requests in addition to Bearer auth.

**Creating keys**: Dashboard → API Keys → Create. The full key is shown **only once** at creation. Keys support optional expiration, model restrictions, and rate limits (tokens / cost per day / week / month).

## Configuration

Environment variables with `CODEX_LB_` prefix or `.env.local`. See [`.env.example`](.env.example).
Dashboard auth is configured in Settings.
SQLite is the default database backend; PostgreSQL is optional via `CODEX_LB_DATABASE_URL` (for example `postgresql+asyncpg://...`).
Container startup also honors `PORT` and auto-loads `/app/.env` when that file is mounted.

## Data

Expand Down
17 changes: 15 additions & 2 deletions app/cli.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import argparse
import logging
import os

import uvicorn

from app.core.runtime_logging import build_log_config
from app.core.logging import build_log_config, configure_logging

logger = logging.getLogger(__name__)


def _parse_args() -> argparse.Namespace:
Expand All @@ -24,13 +27,23 @@ def main() -> None:
if bool(args.ssl_certfile) ^ bool(args.ssl_keyfile):
raise SystemExit("Both --ssl-certfile and --ssl-keyfile must be provided together.")

log_level = configure_logging()
logger.info(
"Starting codex-lb host=%s port=%s ssl=%s log_level=%s access_log=%s",
args.host,
args.port,
bool(args.ssl_certfile and args.ssl_keyfile),
log_level,
False,
)
uvicorn.run(
"app.main:app",
host=args.host,
port=args.port,
ssl_certfile=args.ssl_certfile,
ssl_keyfile=args.ssl_keyfile,
log_config=build_log_config(),
access_log=False,
log_config=build_log_config(log_level),
)


Expand Down
14 changes: 13 additions & 1 deletion app/core/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import hashlib
import json
from dataclasses import dataclass
from datetime import datetime
from datetime import UTC, datetime
from uuid import uuid4

from pydantic import BaseModel, ConfigDict, Field
Expand Down Expand Up @@ -79,6 +79,18 @@ def extract_id_token_claims(id_token: str) -> IdTokenClaims:
return IdTokenClaims()


def token_expiry(token: str | None) -> datetime | None:
if not token:
return None
claims = extract_id_token_claims(token)
exp = claims.exp
if isinstance(exp, (int, float)):
return datetime.fromtimestamp(exp, tz=UTC)
if isinstance(exp, str) and exp.isdigit():
return datetime.fromtimestamp(int(exp), tz=UTC)
return None


def claims_from_auth(auth: AuthFile) -> AccountClaims:
claims = extract_id_token_claims(auth.tokens.id_token)
auth_claims = claims.auth or OpenAIAuthClaims()
Expand Down
23 changes: 23 additions & 0 deletions app/core/auth/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import logging
import secrets

from fastapi import Request, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from app.core.clients.usage import UsageFetchError, fetch_usage
from app.core.config.settings import get_settings
from app.core.config.settings_cache import get_settings_cache
from app.core.exceptions import DashboardAuthError, ProxyAuthError, ProxyUpstreamError
from app.db.session import get_background_session
Expand Down Expand Up @@ -34,8 +36,10 @@ def set_dashboard_error_format(request: Request) -> None:


async def validate_proxy_api_key(
request: Request,
credentials: HTTPAuthorizationCredentials | None = Security(_bearer),
) -> ApiKeyData | None:
_validate_optional_proxy_key_header(request)
authorization = None if credentials is None else f"Bearer {credentials.credentials}"
return await validate_proxy_api_key_authorization(authorization)

Expand Down Expand Up @@ -124,3 +128,22 @@ def _extract_bearer_token(authorization: str | None) -> str | None:
if not token:
return None
return token


def _validate_optional_proxy_key_header(request: Request) -> None:
settings = get_settings()
if not settings.proxy_key_auth_enabled:
return

required_key = settings.proxy_key
if not required_key:
raise ProxyAuthError("X-Codex-Proxy-Key auth is enabled but no proxy key is configured")

provided = request.headers.get("X-Codex-Proxy-Key")
if not provided:
raise ProxyAuthError("Missing X-Codex-Proxy-Key header")
provided_key = provided.strip()
if not provided_key:
raise ProxyAuthError("Missing X-Codex-Proxy-Key header")
if not secrets.compare_digest(provided_key, required_key):
raise ProxyAuthError("Invalid X-Codex-Proxy-Key header")
62 changes: 57 additions & 5 deletions app/core/auth/refresh.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import contextvars
import logging
import os
from dataclasses import dataclass
from datetime import datetime, timedelta

Expand All @@ -11,13 +12,15 @@
from app.core.auth import OpenAIAuthClaims, extract_id_token_claims
from app.core.auth.models import OAuthTokenPayload
from app.core.balancer import PERMANENT_FAILURE_CODES
from app.core.clients.http import get_http_client
from app.core.clients.http import get_http_client, get_http_proxy_request_kwargs
from app.core.config.settings import get_settings
from app.core.types import JsonObject
from app.core.utils.request_id import get_request_id
from app.core.utils.time import to_utc_naive, utcnow

TOKEN_REFRESH_INTERVAL_DAYS = 8
DEFAULT_REFRESH_TOKEN_URL = "https://auth.openai.com/oauth/token"
REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR = "CODEX_REFRESH_TOKEN_URL_OVERRIDE"

logger = logging.getLogger(__name__)
_TOKEN_REFRESH_TIMEOUT_OVERRIDE: contextvars.ContextVar[float | None] = contextvars.ContextVar(
Expand Down Expand Up @@ -57,13 +60,20 @@ def classify_refresh_error(code: str | None) -> bool:
return code in PERMANENT_FAILURE_CODES


def refresh_token_endpoint() -> str:
override = os.getenv(REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR)
if override and override.strip():
return override.strip()
return DEFAULT_REFRESH_TOKEN_URL


async def refresh_access_token(
refresh_token: str,
*,
session: aiohttp.ClientSession | None = None,
) -> TokenRefreshResult:
settings = get_settings()
url = f"{settings.auth_base_url.rstrip('/')}/oauth/token"
url = refresh_token_endpoint()
payload = {
"grant_type": "refresh_token",
"client_id": settings.oauth_client_id,
Expand All @@ -77,7 +87,8 @@ async def refresh_access_token(
request_id = get_request_id()
if request_id:
headers["x-request-id"] = request_id
async with client_session.post(url, json=payload, headers=headers, timeout=timeout) as resp:
proxy_kwargs = await get_http_proxy_request_kwargs()
async with client_session.post(url, json=payload, headers=headers, timeout=timeout, **proxy_kwargs) as resp:
data = await _safe_json(resp)
try:
payload_data = OAuthTokenPayload.model_validate(data)
Expand Down Expand Up @@ -132,9 +143,9 @@ async def _safe_json(resp: aiohttp.ClientResponse) -> JsonObject:


def _refresh_error_from_payload(payload: OAuthTokenPayload, status_code: int) -> RefreshError:
code = _extract_error_code(payload) or f"http_{status_code}"
message = _extract_error_message(payload) or f"Token refresh failed ({status_code})"
return RefreshError(code, message, classify_refresh_error(code))
code = _normalize_refresh_error_code(_extract_error_code(payload), message, status_code)
return RefreshError(code, message, _is_permanent_refresh_failure(code, message, status_code))


def _effective_token_refresh_timeout(configured_timeout_seconds: float) -> float:
Expand Down Expand Up @@ -162,3 +173,44 @@ def _extract_error_message(payload: OAuthTokenPayload) -> str | None:
if isinstance(error, str):
return payload.error_description or error
return payload.message


def _is_permanent_refresh_failure(code: str | None, message: str, status_code: int) -> bool:
if classify_refresh_error(code):
return True

normalized_code = (code or "").strip().lower()
normalized_message = message.strip().lower()
if status_code != 401:
return False

permanent_codes = {
"invalid_grant",
"token_expired",
"session_expired",
}
if normalized_code in permanent_codes:
return True

permanent_message_fragments = (
"refresh token has already been used",
"provided authentication token is expired",
"please try signing in again",
"re-login required",
"token is expired",
"token expired",
)
return any(fragment in normalized_message for fragment in permanent_message_fragments)


def _normalize_refresh_error_code(code: str | None, message: str, status_code: int) -> str:
normalized_code = (code or "").strip().lower()
normalized_message = message.strip().lower()

if status_code == 401:
if "refresh token has already been used" in normalized_message:
return "refresh_token_reused"
if "provided authentication token is expired" in normalized_message or "token expired" in normalized_message:
return "refresh_token_expired"

return code or f"http_{status_code}"
4 changes: 3 additions & 1 deletion app/core/clients/codex_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import aiohttp
import anyio

from app.core.clients.http import get_http_proxy_request_kwargs
from app.core.config.settings import get_settings

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -59,9 +60,10 @@ async def invalidate(self) -> None:
async def _fetch_latest_version(self) -> str | None:
try:
timeout = aiohttp.ClientTimeout(total=_FETCH_TIMEOUT_SECONDS)
proxy_kwargs = await get_http_proxy_request_kwargs()
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
headers = {"Accept": "application/vnd.github+json"}
async with session.get(_GITHUB_RELEASES_URL, headers=headers) as resp:
async with session.get(_GITHUB_RELEASES_URL, headers=headers, **proxy_kwargs) as resp:
if resp.status != 200:
logger.warning("GitHub releases API returned HTTP %d", resp.status)
return None
Expand Down
23 changes: 22 additions & 1 deletion app/core/clients/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import aiohttp
from aiohttp_retry import RetryClient

from app.core.config.proxy import normalize_http_proxy_url
from app.core.config.settings import get_settings
from app.core.config.settings_cache import get_settings_cache
from app.core.config.settings import get_settings


@dataclass(slots=True)
class HttpClient:
Expand Down Expand Up @@ -59,3 +61,22 @@ def get_http_client() -> HttpClient:
if _http_client is None:
raise RuntimeError("HTTP client not initialized")
return _http_client


async def get_http_proxy_url() -> str | None:
env_proxy = normalize_http_proxy_url(get_settings().http_proxy_url)
if env_proxy:
return env_proxy

try:
settings_row = await get_settings_cache().get()
except Exception:
return None
return normalize_http_proxy_url(getattr(settings_row, "http_proxy_url", None))


async def get_http_proxy_request_kwargs() -> dict[str, str]:
proxy = await get_http_proxy_url()
if not proxy:
return {}
return {"proxy": proxy}
5 changes: 3 additions & 2 deletions app/core/clients/model_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import aiohttp

from app.core.clients.codex_version import get_codex_version_cache
from app.core.clients.http import get_http_client
from app.core.clients.http import get_http_client, get_http_proxy_request_kwargs
from app.core.config.settings import get_settings
from app.core.openai.model_registry import ReasoningLevel, UpstreamModel
from app.core.types import JsonValue
Expand Down Expand Up @@ -99,8 +99,9 @@ async def fetch_models_for_plan(

timeout = aiohttp.ClientTimeout(total=_FETCH_TIMEOUT_SECONDS)
session = get_http_client().session
proxy_kwargs = await get_http_proxy_request_kwargs()

async with session.get(url, headers=headers, timeout=timeout) as resp:
async with session.get(url, headers=headers, timeout=timeout, **proxy_kwargs) as resp:
if resp.status >= 400:
text = await resp.text()
raise ModelFetchError(resp.status, f"HTTP {resp.status}: {text[:200]}")
Expand Down
Loading