diff --git a/docs/OAUTH.md b/docs/OAUTH.md new file mode 100644 index 000000000..c9b657e9e --- /dev/null +++ b/docs/OAUTH.md @@ -0,0 +1,130 @@ +# OpenAI OAuth 2.0 Guide + +## Why OAuth? + +- Browser-based OpenAI account authentication +- Automatic token refresh +- Codex CLI credential import +- Account status inspection in Pantheon + +## Important Limitation + +Pantheon's OAuth support manages OpenAI account credentials only. It does not treat the +resulting OAuth access token as a substitute for `OPENAI_API_KEY` when calling the OpenAI API. + +The current exception is Pantheon's dedicated Codex transport: models whose name contains +`codex` can be routed through the ChatGPT/Codex backend using OAuth credentials when available. + +You can trigger this path explicitly with a Codex-prefixed model name such as: + +```bash +/model codex/gpt-5.4 +``` + +To call OpenAI models through the standard OpenAI API path, you still need one of: + +- `OPENAI_API_KEY` +- `LLM_API_KEY` with a compatible base URL +- `CUSTOM_OPENAI_API_BASE` plus `CUSTOM_OPENAI_API_KEY` + +## Integration Risk + +Pantheon's OpenAI OAuth integration reuses the Codex CLI OAuth client identity. +This is not a public third-party OAuth app registration flow. + +Implications: + +- OpenAI can revoke, restrict, or change this integration path at any time +- A working setup today may break without a Pantheon code change +- This path should be treated as best-effort, not as a long-term stable contract + +For maintainers: + +- Do not assume the current client ID / originator values are durable +- Prefer isolating Codex-specific OAuth behavior from standard OpenAI API auth +- Be prepared to disable or replace this path if OpenAI changes upstream behavior + +## Quick Start + +```bash +pantheon +/oauth login +# Browser opens - log in and authorize +``` + +## REPL Commands + +| Command | Description | +|---------|-------------| +| `/oauth status` | Check authentication | +| `/oauth login` | Initiate login | +| `/oauth logout` | Clear credentials | + +## API Reference + +### `get_oauth_manager() -> OAuthManager` + +Get the singleton provider registry for OAuth-capable providers. + +### `OpenAIOAuthProvider` + +| Method | Returns | Description | +|--------|---------|-------------| +| `login()` | `bool` | Start OAuth flow | +| `ensure_access_token()` | `str\|None` | Get a valid access token | +| `ensure_access_token_with_codex_fallback()` | `str\|None` | Get token, importing Codex CLI auth if needed | +| `build_codex_auth_context()` | `dict\|None` | Build ChatGPT/Codex backend auth context | +| `get_status()` | `OAuthStatus` | Current auth status | +| `logout()` | `None` | Revoke tokens and remove local auth file | + +### Example + +```python +from pantheon.auth.oauth_manager import get_oauth_manager + +mgr = get_oauth_manager() +provider = mgr.get_provider("openai") +token = provider.ensure_access_token() +if token: + status = provider.get_status() + print(f"Logged in as: {status.email}") +``` + +## Configuration + +```python +# Custom token location +from pathlib import Path +from pantheon.auth.openai_provider import OpenAIOAuthProvider + +provider = OpenAIOAuthProvider(auth_path=Path("/custom/path.json")) +``` + +```bash +# OpenAI API model calls still require an API key +export OPENAI_API_KEY="sk-..." +``` + +## Troubleshooting + +| Error | Solution | +|-------|----------| +| `No module named 'requests'` | `pip install requests` | +| `No module named 'jwt'` | `pip install pyjwt` | +| `No module named 'cryptography'` | `pip install cryptography` | +| Browser didn't open | Set default browser in OS settings | +| Token expired | Run `/oauth login` to re-authenticate | +| Can't import Codex | Use browser login instead | + +## Security + +- Tokens stored at `~/.pantheon/oauth_openai.json` +- Tokens auto-refresh when ~5 min from expiry +- JWT claims used for email / org / project context are signature-verified before use +- OAuth callback requests are checked against `Origin` / `Referer` when headers are present +- Use `/oauth logout` on shared systems + +## See Also + +- [OpenAI OAuth Docs](https://platform.openai.com/docs/guides/oauth) +- [PKCE RFC 7636](https://datatracker.ietf.org/doc/html/rfc7636) diff --git a/pantheon/auth/__init__.py b/pantheon/auth/__init__.py new file mode 100644 index 000000000..f47a9588c --- /dev/null +++ b/pantheon/auth/__init__.py @@ -0,0 +1,6 @@ +""" +Pantheon authentication modules. + +This package provides authentication support for various LLM providers, +including OAuth 2.0 integration with OpenAI Codex. +""" diff --git a/pantheon/auth/oauth_manager.py b/pantheon/auth/oauth_manager.py new file mode 100644 index 000000000..9e9466145 --- /dev/null +++ b/pantheon/auth/oauth_manager.py @@ -0,0 +1,239 @@ +""" +OAuth Types and Protocols for PantheonOS. +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Protocol, Optional + +@dataclass +class OAuthTokens: + """Generic OAuth tokens.""" + id_token: str + access_token: str + refresh_token: str + account_id: Optional[str] = None + organization_id: Optional[str] = None + project_id: Optional[str] = None + + @classmethod + def from_dict(cls, data: dict) -> OAuthTokens: + return cls( + id_token=data["id_token"], + access_token=data["access_token"], + refresh_token=data["refresh_token"], + account_id=data.get("account_id"), + organization_id=data.get("organization_id"), + project_id=data.get("project_id"), + ) + + def to_dict(self) -> dict: + return { + "id_token": self.id_token, + "access_token": self.access_token, + "refresh_token": self.refresh_token, + "account_id": self.account_id, + "organization_id": self.organization_id, + "project_id": self.project_id, + } + + +@dataclass +class OAuthStatus: + """Generic OAuth status.""" + authenticated: bool + email: str = "" + organization_id: Optional[str] = None + project_id: Optional[str] = None + token_expires_at: Optional[float] = None + provider: str = "" + + +@dataclass +class AuthRecord: + """Generic OAuth auth record.""" + provider: str + tokens: OAuthTokens + last_refresh: str + email: Optional[str] = None + + @classmethod + def from_dict(cls, data: dict) -> AuthRecord: + return cls( + provider=data.get("provider", "unknown"), + tokens=OAuthTokens.from_dict(data["tokens"]), + last_refresh=data.get("last_refresh", ""), + email=data.get("email"), + ) + + def to_dict(self) -> dict: + result = { + "provider": self.provider, + "tokens": self.tokens.to_dict(), + "last_refresh": self.last_refresh, + } + if self.email: + result["email"] = self.email + return result + + +class OAuthProvider(Protocol): + """Protocol for OAuth providers.""" + + @property + def name(self) -> str: + """Provider name.""" + ... + + @property + def display_name(self) -> str: + """Display name for UI.""" + ... + + def login( + self, + *, + open_browser: bool = True, + timeout_seconds: int = 300, + ) -> bool: + """Initiate OAuth login flow.""" + ... + + def get_status(self) -> OAuthStatus: + """Get current OAuth status.""" + ... + + def logout(self) -> None: + """Clear OAuth credentials.""" + ... + + def ensure_access_token(self, refresh_if_needed: bool = True) -> Optional[str]: + """Get a valid access token.""" + ... + + +class OAuthManager: + """Manages multiple OAuth providers.""" + + def __init__(self): + self._providers: dict[str, OAuthProvider] = {} + self._default_provider: str = "openai" + + def register(self, provider: OAuthProvider) -> None: + """Register an OAuth provider.""" + self._providers[provider.name] = provider + + def set_default(self, provider_name: str) -> None: + """Set the default provider.""" + if provider_name not in self._providers: + raise ValueError(f"Unknown provider: {provider_name}") + self._default_provider = provider_name + + @property + def default_provider(self) -> str: + """Get the default provider name.""" + return self._default_provider + + def list_providers(self) -> list[str]: + """List all registered provider names.""" + return list(self._providers.keys()) + + def get_provider(self, name: Optional[str] = None) -> OAuthProvider: + """Get a provider by name, or the default provider.""" + provider_name = name or self._default_provider + if provider_name not in self._providers: + raise ValueError(f"Unknown provider: {provider_name}") + return self._providers[provider_name] + + def login( + self, + provider: Optional[str] = None, + *, + open_browser: bool = True, + timeout_seconds: int = 300, + ) -> bool: + """Login with a specific provider.""" + p = self.get_provider(provider) + return p.login(open_browser=open_browser, timeout_seconds=timeout_seconds) + + def get_status(self, provider: Optional[str] = None) -> OAuthStatus: + """Get status from a specific provider.""" + p = self.get_provider(provider) + status = p.get_status() + status.provider = p.name + return status + + def logout(self, provider: Optional[str] = None) -> None: + """Logout from a specific provider.""" + p = self.get_provider(provider) + p.logout() + + def ensure_access_token( + self, + provider: Optional[str] = None, + refresh_if_needed: bool = True, + ) -> Optional[str]: + """Get a valid access token from a specific provider.""" + p = self.get_provider(provider) + return p.ensure_access_token(refresh_if_needed=refresh_if_needed) + + +_oauth_manager: Optional[OAuthManager] = None + + +def get_oauth_manager() -> OAuthManager: + """Get the OAuth manager singleton.""" + global _oauth_manager + if _oauth_manager is None: + _oauth_manager = OAuthManager() + from pantheon.auth.openai_provider import OpenAIOAuthProvider + _oauth_manager.register(OpenAIOAuthProvider()) + return _oauth_manager + + +def reset_oauth_manager() -> None: + """Reset the OAuth manager singleton.""" + global _oauth_manager + _oauth_manager = None + + +def get_oauth_token(provider: str = "openai", refresh_if_needed: bool = True) -> Optional[str]: + """Get a valid OAuth access token for the specified provider. + + This is a convenience function for other modules to get OAuth tokens. + + Args: + provider: The OAuth provider name (default: "openai") + refresh_if_needed: Whether to refresh the token if expired + + Returns: + The access token string, or None if not available + """ + try: + manager = get_oauth_manager() + return manager.ensure_access_token(provider, refresh_if_needed) + except Exception: + return None + + +def is_oauth_available(provider: str = "openai") -> bool: + """Check if OAuth is available for the specified provider. + + Args: + provider: The OAuth provider name (default: "openai") + + Returns: + True if OAuth tokens are available, False otherwise + """ + try: + from pathlib import Path + + manager = get_oauth_manager() + oauth_provider = manager.get_provider(provider) + + # Check if auth file exists + if hasattr(oauth_provider, 'auth_path') and oauth_provider.auth_path.exists(): + return True + return False + except Exception: + return False diff --git a/pantheon/auth/openai_auth_strategy.py b/pantheon/auth/openai_auth_strategy.py new file mode 100644 index 000000000..8556c4b08 --- /dev/null +++ b/pantheon/auth/openai_auth_strategy.py @@ -0,0 +1,101 @@ +""" +OpenAI authentication strategy helpers. + +This module centralizes how Pantheon decides between OpenAI API key auth +and Codex OAuth auth when both are present. +""" +from __future__ import annotations + +from dataclasses import asdict, dataclass +from typing import Any + +from pantheon.settings import get_settings +from pantheon.utils.log import logger + + +VALID_OPENAI_AUTH_MODES = { + "auto", + "prefer_api_key", + "prefer_oauth", + "api_key_only", + "oauth_only", +} + + +@dataclass(frozen=True) +class OpenAIAuthSettings: + mode: str = "auto" + enable_api_key: bool = True + enable_oauth: bool = True + allow_codex_fallback_to_api_key: bool = False + allow_openai_api_fallback_to_oauth: bool = False + + def normalized(self) -> "OpenAIAuthSettings": + mode = str(self.mode or "auto").strip().lower() + if mode not in VALID_OPENAI_AUTH_MODES: + logger.warning(f"Unknown auth.openai.mode '{self.mode}', falling back to 'auto'") + mode = "auto" + return OpenAIAuthSettings( + mode=mode, + enable_api_key=bool(self.enable_api_key), + enable_oauth=bool(self.enable_oauth), + allow_codex_fallback_to_api_key=bool(self.allow_codex_fallback_to_api_key), + allow_openai_api_fallback_to_oauth=bool(self.allow_openai_api_fallback_to_oauth), + ) + + +def get_openai_auth_settings() -> OpenAIAuthSettings: + settings = get_settings() + raw = settings.get("auth.openai", {}) or {} + return OpenAIAuthSettings( + mode=raw.get("mode", "auto"), + enable_api_key=raw.get("enable_api_key", True), + enable_oauth=raw.get("enable_oauth", True), + allow_codex_fallback_to_api_key=raw.get("allow_codex_fallback_to_api_key", False), + allow_openai_api_fallback_to_oauth=raw.get("allow_openai_api_fallback_to_oauth", False), + ).normalized() + +def is_api_key_auth_enabled() -> bool: + prefs = get_openai_auth_settings() + return prefs.enable_api_key and prefs.mode != "oauth_only" + + +def is_oauth_auth_enabled() -> bool: + prefs = get_openai_auth_settings() + return prefs.enable_oauth and prefs.mode != "api_key_only" + + +def should_use_codex_oauth_transport(model_name: str) -> bool: + prefs = get_openai_auth_settings() + if not is_oauth_auth_enabled(): + return False + + lower = (model_name or "").strip().lower() + if lower.startswith("codex/"): + return True + if "codex" in lower and prefs.mode in {"prefer_oauth", "oauth_only"}: + return True + return False + + +def should_treat_openai_api_key_as_available() -> bool: + return is_api_key_auth_enabled() + + +def summarize_openai_auth_state( + *, + api_key_present: bool, + oauth_authenticated: bool, +) -> dict[str, Any]: + prefs = get_openai_auth_settings() + return { + "mode": prefs.mode, + "enable_api_key": prefs.enable_api_key, + "enable_oauth": prefs.enable_oauth, + "allow_codex_fallback_to_api_key": prefs.allow_codex_fallback_to_api_key, + "allow_openai_api_fallback_to_oauth": prefs.allow_openai_api_fallback_to_oauth, + "api_key_present": bool(api_key_present), + "oauth_authenticated": bool(oauth_authenticated), + "effective_api_key_enabled": is_api_key_auth_enabled(), + "effective_oauth_enabled": is_oauth_auth_enabled(), + } diff --git a/pantheon/auth/openai_provider.py b/pantheon/auth/openai_provider.py new file mode 100644 index 000000000..c0f3f9c9c --- /dev/null +++ b/pantheon/auth/openai_provider.py @@ -0,0 +1,700 @@ +""" +OpenAI OAuth 2.0 Provider for PantheonOS. + +Security Notes: +- JWT tokens are base64-decoded for payload extraction but signature is NOT verified. + This is a common simplification for client-side token inspection. The OAuth flow + itself provides security via PKCE and HTTPS. Only use tokens from trusted sources. +- Token files are saved with 0o600 permissions (user-only read/write). +- Logout attempts to revoke tokens on OpenAI's server. + +Known Risks: +- This implementation reuses OpenAI Codex CLI's OAuth client ID and originator. + OpenAI does not currently offer public OAuth app registration for third-party tools. + OpenAI can revoke or restrict this client ID at any time, breaking auth for all users. + This is an undocumented, unsupported integration path that could change without notice. +- OAuth tokens managed here are account credentials. PantheonOS should not inject them + into generic OpenAI API SDK calls as a substitute for ``OPENAI_API_KEY``. +""" +from __future__ import annotations + +import base64 +import hashlib +import json +import os +import secrets +import stat +import threading +import time +import webbrowser +from datetime import datetime, timezone +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from pathlib import Path +from typing import Optional, Callable + +import requests + +from pantheon.auth.oauth_manager import ( + OAuthProvider, + OAuthTokens, + AuthRecord, + OAuthStatus, +) +from pantheon.utils.log import logger + + +OPENAI_AUTH_ISSUER = "https://auth.openai.com" +OPENAI_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" +OPENAI_ORIGINATOR = "pi" +OPENAI_CALLBACK_PORT = 1455 +OPENAI_SCOPE = "openid profile email offline_access" +OPENAI_CODEX_BASE_URL = "https://chatgpt.com/backend-api" +OPENAI_OIDC_CONFIG_URL = f"{OPENAI_AUTH_ISSUER}/.well-known/openid-configuration" +_OIDC_CONFIG_CACHE: dict[str, object] = {"value": None, "expires_at": 0.0} +_JWKS_CLIENT_CACHE: dict[str, object] = {"value": None, "expires_at": 0.0} + + +def _utc_now() -> str: + return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") + + +def _b64url(data: bytes) -> str: + return base64.urlsafe_b64encode(data).decode("ascii").rstrip("=") + + +def _pkce_pair() -> tuple: + verifier = _b64url(secrets.token_bytes(32)) + challenge = _b64url(hashlib.sha256(verifier.encode("utf-8")).digest()) + return verifier, challenge + + +def _decode_jwt_payload_unverified(token: str) -> dict: + parts = (token or "").split(".") + if len(parts) != 3 or not parts[1]: + return {} + payload = parts[1] + payload += "=" * (-len(payload) % 4) + try: + decoded = base64.urlsafe_b64decode(payload.encode("ascii")) + data = json.loads(decoded.decode("utf-8")) + except Exception: + return {} + return data if isinstance(data, dict) else {} + + +def _get_oidc_config() -> dict: + now = time.time() + cached = _OIDC_CONFIG_CACHE.get("value") + if isinstance(cached, dict) and now < float(_OIDC_CONFIG_CACHE.get("expires_at", 0.0)): + return cached + + response = requests.get(OPENAI_OIDC_CONFIG_URL, timeout=10) + response.raise_for_status() + config = response.json() + if not isinstance(config, dict): + raise RuntimeError("OIDC discovery returned invalid payload") + + _OIDC_CONFIG_CACHE["value"] = config + _OIDC_CONFIG_CACHE["expires_at"] = now + 3600 + return config + + +def _get_jwks_client(): + now = time.time() + cached = _JWKS_CLIENT_CACHE.get("value") + if cached is not None and now < float(_JWKS_CLIENT_CACHE.get("expires_at", 0.0)): + return cached + + try: + import jwt + except ImportError as exc: + raise RuntimeError("PyJWT is required for JWT signature verification") from exc + + config = _get_oidc_config() + jwks_uri = str(config.get("jwks_uri") or "").strip() + if not jwks_uri: + raise RuntimeError("OIDC discovery did not include jwks_uri") + + client = jwt.PyJWKClient(jwks_uri) + _JWKS_CLIENT_CACHE["value"] = client + _JWKS_CLIENT_CACHE["expires_at"] = now + 3600 + return client + + +def _decode_jwt_payload_verified(token: str) -> dict: + if not token: + return {} + + try: + import jwt + + jwks_client = _get_jwks_client() + signing_key = jwks_client.get_signing_key_from_jwt(token) + payload = jwt.decode( + token, + signing_key.key, + algorithms=["RS256"], + issuer=OPENAI_AUTH_ISSUER, + options={"verify_aud": False}, + ) + return payload if isinstance(payload, dict) else {} + except Exception as exc: + if exc.__class__.__name__ == "ExpiredSignatureError": + logger.debug("JWT signature verification skipped for expired token during local inspection") + return {} + logger.warning(f"JWT signature verification failed: {exc}") + return {} + + +def _decode_jwt_payload(token: str, *, allow_unverified_fallback: bool = False) -> dict: + payload = _decode_jwt_payload_verified(token) + if payload: + return payload + if allow_unverified_fallback: + return _decode_jwt_payload_unverified(token) + return {} + + +def jwt_auth_claims(token: str) -> dict: + payload = _decode_jwt_payload(token) + nested = payload.get("https://api.openai.com/auth") + return nested if isinstance(nested, dict) else {} + + +def jwt_org_context(token: str) -> dict: + claims = jwt_auth_claims(token) + context = {} + for key in ("organization_id", "project_id", "chatgpt_account_id"): + value = str(claims.get(key) or "").strip() + if value: + context[key] = value + return context + + +def _extract_org_context(token: str) -> dict: + payload = _decode_jwt_payload(token) + nested = payload.get("https://api.openai.com/auth", {}) + if not isinstance(nested, dict): + nested = {} + + context = {} + for key in ("organization_id", "project_id", "chatgpt_account_id"): + value = str(nested.get(key) or "").strip() + if value: + context[key] = value + return context + + +def _token_expired(token: str, skew_seconds: int = 300) -> bool: + payload = _decode_jwt_payload(token, allow_unverified_fallback=True) + exp = payload.get("exp") + if not isinstance(exp, (int, float)): + return True + return time.time() >= (float(exp) - skew_seconds) + + +def _extract_email(token: str) -> str: + payload = _decode_jwt_payload(token) + return payload.get("email", "") + + +def _extract_token_exp(token: str) -> float | None: + payload = _decode_jwt_payload(token, allow_unverified_fallback=True) + exp = payload.get("exp") + if isinstance(exp, (int, float)): + return float(exp) + return None + + +class _OAuthCallbackHandler(BaseHTTPRequestHandler): + server_version = "PantheonOAuth/1.0" + ALLOWED_ORIGINS = {"https://auth.openai.com", "https://openai.com"} + + def _check_origin(self) -> bool: + origin = self.headers.get("Origin", "") + referer = self.headers.get("Referer", "") + + if origin: + for allowed in self.ALLOWED_ORIGINS: + if origin.startswith(allowed): + return True + if referer: + for allowed in self.ALLOWED_ORIGINS: + if referer.startswith(allowed): + return True + if not origin and not referer: + return True + return False + + def do_GET(self) -> None: + from urllib.parse import parse_qs, urlparse + + if not self._check_origin(): + self.send_error(403) + return + + parsed = urlparse(self.path) + if parsed.path != "/auth/callback": + self.send_error(404) + return + + params = {key: values[-1] for key, values in parse_qs(parsed.query).items() if values} + self.server.result = params + self.server.event.set() + + body = ( + "

OpenAI OAuth complete

" + "

You can close this window and return to Pantheon.

" + ) + data = body.encode("utf-8") + self.send_response(200) + self.send_header("Content-Type", "text/html; charset=utf-8") + self.send_header("Content-Length", str(len(data))) + self.end_headers() + self.wfile.write(data) + + def log_message(self, fmt: str, *args: object) -> None: + return + + +class OpenAIOAuthProvider: + """ + OpenAI OAuth provider for Pantheon. + """ + + AUTHORIZATION_ENDPOINT = f"{OPENAI_AUTH_ISSUER}/oauth/authorize" + TOKEN_ENDPOINT = f"{OPENAI_AUTH_ISSUER}/oauth/token" + CLIENT_ID = OPENAI_CLIENT_ID + SCOPE = OPENAI_SCOPE + + _lock = threading.Lock() + + @property + def name(self) -> str: + return "openai" + + @property + def display_name(self) -> str: + return "OpenAI" + + def __init__(self, auth_path: Optional[Path] = None): + if auth_path is None: + auth_path = Path.home() / ".pantheon" / "oauth_openai.json" + self.auth_path = auth_path + + def _create_callback_server(self, event: threading.Event) -> tuple: + for port in (OPENAI_CALLBACK_PORT, 0): + try: + server = ThreadingHTTPServer(("localhost", port), _OAuthCallbackHandler) + server.event = event + server.result = {} + return server, server.server_address[1] + except OSError: + continue + raise RuntimeError("Could not start OAuth callback server") + + def _build_auth_url(self, code_challenge: str, redirect_uri: str, state: str, workspace_id: Optional[str] = None) -> str: + from urllib.parse import urlencode + + params = { + "client_id": self.CLIENT_ID, + "response_type": "code", + "redirect_uri": redirect_uri, + "scope": self.SCOPE, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "id_token_add_organizations": "true", + "codex_cli_simplified_flow": "true", + "originator": OPENAI_ORIGINATOR, + "state": state, + } + + if workspace_id: + params["allowed_workspace_id"] = workspace_id + + return f"{self.AUTHORIZATION_ENDPOINT}?{urlencode(params)}" + + def _exchange_code_for_tokens(self, code: str, redirect_uri: str, code_verifier: str) -> dict: + response = requests.post( + self.TOKEN_ENDPOINT, + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + "client_id": self.CLIENT_ID, + "code_verifier": code_verifier, + }, + timeout=30, + ) + + if not response.ok: + raise RuntimeError(f"OAuth token exchange failed: HTTP {response.status_code} {response.text[:300]}") + + data = response.json() + required_keys = ("id_token", "access_token", "refresh_token") + if not all(data.get(key) for key in required_keys): + raise RuntimeError("OAuth token exchange returned incomplete credentials") + + return data + + def _refresh_token(self, refresh_token: str) -> dict: + response = requests.post( + self.TOKEN_ENDPOINT, + data={ + "client_id": self.CLIENT_ID, + "grant_type": "refresh_token", + "refresh_token": refresh_token, + }, + timeout=30, + ) + + if not response.ok: + raise RuntimeError(f"Token refresh failed: HTTP {response.status_code} {response.text[:300]}") + + data = response.json() + access_token = str(data.get("access_token") or "").strip() + id_token = str(data.get("id_token") or "").strip() + next_refresh = str(data.get("refresh_token") or refresh_token).strip() + + if not access_token or not id_token: + raise RuntimeError("Token refresh returned incomplete credentials") + + return { + "id_token": id_token, + "access_token": access_token, + "refresh_token": next_refresh, + } + + def _build_auth_record(self, tokens_data: dict) -> AuthRecord: + claims = _extract_org_context(tokens_data["id_token"]) + return AuthRecord( + provider="openai", + tokens=OAuthTokens( + id_token=tokens_data["id_token"], + access_token=tokens_data["access_token"], + refresh_token=tokens_data["refresh_token"], + account_id=claims.get("chatgpt_account_id"), + organization_id=claims.get("organization_id"), + project_id=claims.get("project_id"), + ), + last_refresh=_utc_now(), + ) + + def _load_auth_record(self) -> Optional[AuthRecord]: + if not self.auth_path.exists(): + return None + try: + with open(self.auth_path, "r") as f: + data = json.load(f) + return AuthRecord.from_dict(data) + except Exception as e: + logger.warning(f"Failed to load auth record: {e}") + return None + + def _save_auth_record(self, record: AuthRecord) -> None: + try: + self.auth_path.parent.mkdir(parents=True, exist_ok=True) + with open(self.auth_path, "w") as f: + json.dump(record.to_dict(), f, indent=2) + os.chmod(self.auth_path, stat.S_IRUSR | stat.S_IWUSR) + except Exception as e: + logger.warning(f"Failed to save auth record: {e}") + + def _parse_manual_callback(self, value: str) -> dict: + from urllib.parse import parse_qs, urlparse + + text = (value or "").strip() + if not text: + raise ValueError("Missing OAuth callback URL or code/state pair") + + if "://" in text: + parsed = urlparse(text) + params = parse_qs(parsed.query) + return {key: values[-1] for key, values in params.items() if values} + + if "#" in text: + code, state = text.split("#", 1) + return {"code": code.strip(), "state": state.strip()} + + raise ValueError("Could not parse OAuth callback input") + + def login( + self, + *, + workspace_id: Optional[str] = None, + open_browser: bool = True, + timeout_seconds: int = 300, + prompt_for_redirect: Optional[Callable[[str], str]] = None, + ) -> bool: + """ + Initiate OpenAI OAuth login flow. + """ + with self._lock: + verifier, challenge = _pkce_pair() + state = _b64url(secrets.token_bytes(24)) + + event = threading.Event() + server, port = self._create_callback_server(event) + redirect_uri = f"http://localhost:{port}/auth/callback" + + auth_url = self._build_auth_url(challenge, redirect_uri, state, workspace_id) + + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + time.sleep(0.5) + + logger.info(f"OAuth server started on port {port}") + logger.info(f"Callback URL: {redirect_uri}") + + try: + if open_browser: + webbrowser.open(auth_url) + + logger.info("Waiting for OAuth callback...") + + if not event.wait(timeout_seconds): + if prompt_for_redirect is None: + logger.warning("OAuth callback timeout") + raise TimeoutError("Timed out waiting for OpenAI OAuth callback") + else: + manual = prompt_for_redirect(auth_url) + params = self._parse_manual_callback(manual) + else: + params = dict(getattr(server, "result", {}) or {}) + finally: + try: + server.shutdown() + server.server_close() + except Exception: + pass + try: + thread.join(timeout=2) + except Exception: + pass + + if params.get("state") != state: + raise ValueError("OAuth callback state mismatch") + + if params.get("error"): + detail = str(params.get("error_description") or params["error"]) + raise RuntimeError(f"OpenAI OAuth failed: {detail}") + + code = str(params.get("code") or "").strip() + if not code: + raise ValueError("OAuth callback did not include a code") + + tokens_data = self._exchange_code_for_tokens(code, redirect_uri, verifier) + record = self._build_auth_record(tokens_data) + self._save_auth_record(record) + + logger.info("OpenAI OAuth login successful") + return True + + def refresh(self) -> bool: + """Refresh the access token.""" + auth = self._load_auth_record() + if not auth or not auth.tokens.refresh_token: + raise ValueError("No refresh token available") + + refreshed = self._refresh_token(auth.tokens.refresh_token) + record = self._build_auth_record(refreshed) + self._save_auth_record(record) + return True + + def ensure_access_token(self, refresh_if_needed: bool = True) -> Optional[str]: + """Get a valid access token.""" + auth = self._load_auth_record() + if not auth: + return None + + access_token = auth.tokens.access_token + refresh_token = auth.tokens.refresh_token + + if refresh_if_needed and refresh_token and (not access_token or _token_expired(access_token)): + self.refresh() + auth = self._load_auth_record() + access_token = auth.tokens.access_token if auth else None + + return access_token + + def ensure_access_token_with_codex_fallback( + self, + *, + refresh_if_needed: bool = True, + import_codex_if_missing: bool = True, + ) -> Optional[str]: + """Return a usable access token, importing Codex CLI auth when available.""" + access_token = self.ensure_access_token(refresh_if_needed=refresh_if_needed) + if access_token or not import_codex_if_missing: + return access_token + + imported = import_from_codex_cli() + if not imported: + return None + + return self.ensure_access_token(refresh_if_needed=refresh_if_needed) + + def build_codex_auth_context( + self, + *, + refresh_if_needed: bool = True, + import_codex_if_missing: bool = True, + ) -> Optional[dict]: + """Build auth context for Codex-specific OAuth calls. + + This is intentionally separate from generic OpenAI API auth. The returned + context is only meant for the Codex/ChatGPT backend path. + """ + access_token = self.ensure_access_token_with_codex_fallback( + refresh_if_needed=refresh_if_needed, + import_codex_if_missing=import_codex_if_missing, + ) + if not access_token: + return None + + auth = self._load_auth_record() + tokens = auth.tokens if auth else OAuthTokens("", access_token, "") + + return { + "base_url": f"{OPENAI_CODEX_BASE_URL}/codex", + "access_token": access_token, + "account_id": tokens.account_id, + "organization_id": tokens.organization_id, + "project_id": tokens.project_id, + } + + def _status_from_auth_record(self, auth: AuthRecord | None) -> OAuthStatus: + """Build status from a loaded auth record without refreshing tokens.""" + if not auth or not auth.tokens.access_token: + return OAuthStatus(authenticated=False, provider="openai") + + access_token = auth.tokens.access_token + id_token = auth.tokens.id_token + + token_expires_at = _extract_token_exp(id_token) if id_token else None + if token_expires_at is None: + token_expires_at = _extract_token_exp(access_token) if access_token else None + + return OAuthStatus( + authenticated=bool(access_token), + email=_extract_email(id_token) if id_token else "", + organization_id=auth.tokens.organization_id, + project_id=auth.tokens.project_id, + token_expires_at=token_expires_at, + provider="openai", + ) + + def peek_status(self) -> OAuthStatus: + """Read current OAuth status from disk without refreshing tokens.""" + auth = self._load_auth_record() + return self._status_from_auth_record(auth) + + def get_status(self) -> OAuthStatus: + """Get current OAuth status, refreshing expired tokens when possible.""" + auth = self._load_auth_record() + if not auth or not auth.tokens.access_token: + return OAuthStatus(authenticated=False, provider="openai") + + access_token = auth.tokens.access_token + if access_token and _token_expired(access_token): + refresh_token = auth.tokens.refresh_token + if refresh_token: + try: + self.refresh() + auth = self._load_auth_record() + except Exception as e: + logger.warning(f"Token refresh failed: {e}") + + return self._status_from_auth_record(auth) + + def logout(self) -> None: + """Clear OAuth credentials and revoke tokens on OpenAI server.""" + auth = self._load_auth_record() + + if auth and auth.tokens.access_token: + try: + requests.post( + f"{OPENAI_AUTH_ISSUER}/oauth/revoke", + data={"token": auth.tokens.access_token}, + timeout=10, + ) + except Exception as e: + logger.warning(f"Failed to revoke access token: {e}") + + try: + requests.post( + f"{OPENAI_AUTH_ISSUER}/oauth/revoke", + data={"token": auth.tokens.refresh_token}, + timeout=10, + ) + except Exception as e: + logger.warning(f"Failed to revoke refresh token: {e}") + + if self.auth_path.exists(): + self.auth_path.unlink() + + +CODEX_CLI_AUTH_PATH = Path.home() / ".codex" / "auth.json" + + +def import_from_codex_cli() -> bool: + """Import authentication from Codex CLI. + + Reads the existing Codex CLI authentication and converts it to our format. + This allows PantheonOS to use Codex CLI's existing login session. + + Returns: + True if import successful, False otherwise + """ + import json + from datetime import datetime, timezone + + if not CODEX_CLI_AUTH_PATH.exists(): + logger.warning(f"Codex CLI auth file not found: {CODEX_CLI_AUTH_PATH}") + return False + + try: + with open(CODEX_CLI_AUTH_PATH, "r") as f: + codex_data = json.load(f) + + tokens_data = codex_data.get("tokens", {}) + if not tokens_data: + logger.warning("Codex CLI auth file has no tokens") + return False + + access_token = tokens_data.get("access_token") + id_token = tokens_data.get("id_token") + refresh_token = tokens_data.get("refresh_token") + + if not access_token: + logger.warning("Codex CLI has no access token") + return False + + account_id = tokens_data.get("account_id") + + auth_record = AuthRecord( + provider="openai", + tokens=OAuthTokens( + id_token=id_token or "", + access_token=access_token, + refresh_token=refresh_token or "", + account_id=account_id, + ), + last_refresh=datetime.now(timezone.utc).isoformat(), + email=_extract_email(id_token) if id_token else "", + ) + + provider = OpenAIOAuthProvider() + provider._save_auth_record(auth_record) + + logger.info(f"Successfully imported Codex CLI authentication") + return True + + except Exception as e: + logger.error(f"Failed to import Codex CLI auth: {e}") + return False + + +def get_openai_oauth_provider() -> OpenAIOAuthProvider: + """Get the OpenAI OAuth provider.""" + return OpenAIOAuthProvider() diff --git a/pantheon/chatroom/room.py b/pantheon/chatroom/room.py index 488d54665..99ce8f30a 100644 --- a/pantheon/chatroom/room.py +++ b/pantheon/chatroom/room.py @@ -2141,7 +2141,7 @@ async def compress_chat(self, chat_id: str) -> dict: return {"success": False, "message": str(e)} def _validate_model_provider(self, model: str) -> tuple[bool, str]: - """Validate that the provider for a model has a valid API key. + """Validate that the provider for a model has usable credentials. Args: model: Model name or tag. @@ -2169,6 +2169,25 @@ def _validate_model_provider(self, model: str) -> tuple[bool, str]: } provider = provider_aliases.get(provider, provider) + if provider == "codex": + try: + from pantheon.auth.openai_auth_strategy import is_oauth_auth_enabled + from pantheon.auth.openai_provider import get_openai_oauth_provider + + if not is_oauth_auth_enabled(): + return False, "Provider 'codex' disabled by auth.openai settings" + + oauth_provider = get_openai_oauth_provider() + context = oauth_provider.build_codex_auth_context( + refresh_if_needed=True, + import_codex_if_missing=True, + ) + if context and context.get("access_token"): + return True, "" + return False, "Provider 'codex' not available (missing OAuth login)" + except Exception: + return False, "Provider 'codex' not available (missing OAuth login)" + if provider not in available: return False, f"Provider '{provider}' not available (missing API key)" diff --git a/pantheon/factory/templates/settings.json b/pantheon/factory/templates/settings.json index 9b85093b9..779a092a3 100644 --- a/pantheon/factory/templates/settings.json +++ b/pantheon/factory/templates/settings.json @@ -107,6 +107,19 @@ "SCRAPER_API_KEY": null, // ScraperAPI key "HUGGINGFACE_TOKEN": null }, + // ===== OpenAI Authentication Strategy ===== + "auth": { + "openai": { + // auto | prefer_api_key | prefer_oauth | api_key_only | oauth_only + "mode": "auto", + "enable_api_key": true, + "enable_oauth": true, + // Keep false by default: codex transport is OAuth-specific today. + "allow_codex_fallback_to_api_key": false, + // Keep false by default: OAuth tokens are not generic OpenAI API keys. + "allow_openai_api_fallback_to_oauth": false + } + }, // ===== Knowledge/RAG Configuration ===== "knowledge": { // Knowledge base storage path @@ -240,4 +253,4 @@ // Jitter factor (0.0-1.0) to randomize delay and avoid thundering herd "jitter": 0.5 } -} \ No newline at end of file +} diff --git a/pantheon/repl/__init__.py b/pantheon/repl/__init__.py index 7d29cf440..4d04b5422 100644 --- a/pantheon/repl/__init__.py +++ b/pantheon/repl/__init__.py @@ -3,6 +3,11 @@ # Prevent litellm from making blocking network calls to GitHub on startup os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True") -from .core import Repl +__all__ = ["Repl"] -__all__ = ["Repl"] \ No newline at end of file + +def __getattr__(name: str): + if name == "Repl": + from .core import Repl + return Repl + raise AttributeError(name) diff --git a/pantheon/repl/core.py b/pantheon/repl/core.py index 51174c0ec..d772f3d1c 100644 --- a/pantheon/repl/core.py +++ b/pantheon/repl/core.py @@ -1056,6 +1056,12 @@ async def _handle_message_or_command(self, current_message: str): self._handle_keys_command(args) return + # OAuth command + elif cmd_lower.startswith("/oauth"): + args = cmd[7:].strip() # Handles both "/oauth" and "/oauth login" + await self._handle_oauth_command(args) + return + # Verbose mode command elif cmd_lower in ["/verbose", "/v"]: self.set_display_mode(DisplayMode.VERBOSE) @@ -2238,6 +2244,213 @@ def _handle_keys_command(self, args: str): reset_model_selector() self.console.print(f"[green]\u2713[/green] {display_name} ({env_var}) saved to ~/.pantheon/.env") + async def _handle_oauth_command(self, args: str): + """Handle /oauth command - manage OAuth authentication. + + Usage: + /oauth login [provider] - Start OAuth login flow (default: openai) + /oauth status [provider] - Show OAuth authentication status + /oauth logout [provider] - Clear OAuth credentials + /oauth import-codex - Import authentication from Codex CLI + /oauth prefs - Show API key/OAuth routing preferences + /oauth mode - Set auth mode + /oauth enable + /oauth disable + """ + from pantheon.auth.oauth_manager import get_oauth_manager + from pantheon.auth.openai_auth_strategy import ( + VALID_OPENAI_AUTH_MODES, + ) + from pantheon.repl.setup_wizard import ( + _save_openai_auth_settings_to_settings, + get_openai_auth_summary_state, + ) + from pantheon.settings import get_settings + import asyncio + import os + + parts = args.lower().strip().split() if args else [] + subcommand = parts[0] if parts else "status" + provider = parts[1] if len(parts) > 1 else None + + oauth_manager = get_oauth_manager() + + if subcommand == "list": + self.console.print() + self.console.print("[bold]Available OAuth Providers[/bold]") + self.console.print() + + providers = oauth_manager.list_providers() + default_provider = oauth_manager.default_provider + + for p in providers: + marker = " (default)" if p == default_provider else "" + self.console.print(f" • {p}{marker}") + + self.console.print() + self.console.print("[dim]Usage: /oauth login [/dim]") + self.console.print() + + elif subcommand == "login": + self.console.print() + provider_name = provider or "openai" + self.console.print(f"[bold]{provider_name.title()} OAuth Login[/bold]") + self.console.print("[dim]A browser window will open for you to authenticate.[/dim]") + self.console.print() + + try: + loop = asyncio.get_event_loop() + success = await loop.run_in_executor( + None, + lambda: oauth_manager.login(provider) + ) + + if success: + status = oauth_manager.get_status(provider) + self.console.print(f"[green]✓ {provider_name.title()} OAuth login successful![/green]") + self.console.print("[dim]This logs in your OpenAI account, but does not replace OPENAI_API_KEY for OpenAI API model calls.[/dim]") + if status.email: + self.console.print(f" Email: {status.email}") + if status.organization_id: + self.console.print(f" Organization ID: {status.organization_id}") + if status.project_id: + self.console.print(f" Project ID: {status.project_id}") + self.console.print() + else: + self.console.print(f"[red]✗ {provider_name.title()} OAuth login failed[/red]") + self.console.print("[dim]Please try again or check your internet connection.[/dim]") + self.console.print() + except Exception as e: + self.console.print(f"[red]✗ OAuth login error: {e}[/red]") + self.console.print() + + elif subcommand == "status": + self.console.print() + provider_name = provider or oauth_manager.default_provider + self.console.print(f"[bold]{provider_name.title()} OAuth Status[/bold]") + self.console.print() + + try: + status = oauth_manager.get_status(provider) + + if status.authenticated: + self.console.print("[green]✓ Authenticated[/green]") + if status.email: + self.console.print(f" Email: {status.email}") + if status.organization_id: + self.console.print(f" Organization: {status.organization_id}") + if status.project_id: + self.console.print(f" Project: {status.project_id}") + if status.token_expires_at: + self.console.print(f" Token Expires: {status.token_expires_at}") + else: + self.console.print("[yellow]Not authenticated[/yellow]") + self.console.print("[dim]Use '/oauth login openai' to authenticate.[/dim]") + self.console.print() + except Exception as e: + self.console.print(f"[red]✗ Failed to get OAuth status: {e}[/red]") + self.console.print() + + elif subcommand == "logout": + self.console.print() + provider_name = provider or oauth_manager.default_provider + self.console.print(f"[bold]{provider_name.title()} OAuth Logout[/bold]") + self.console.print() + + try: + oauth_manager.logout(provider) + self.console.print(f"[green]✓ {provider_name.title()} OAuth credentials cleared[/green]") + self.console.print("[dim]Use '/oauth login openai' to authenticate again.[/dim]") + self.console.print() + except Exception as e: + self.console.print(f"[red]✗ Failed to logout: {e}[/red]") + self.console.print() + + elif subcommand == "import-codex": + self.console.print() + self.console.print("[bold]Import from Codex CLI[/bold]") + self.console.print("[dim]Reading existing Codex CLI authentication...[/dim]") + self.console.print() + + try: + from pantheon.auth.openai_provider import import_from_codex_cli + success = import_from_codex_cli() + + if success: + status = oauth_manager.get_status("openai") + self.console.print("[green]✓ Successfully imported Codex CLI authentication![/green]") + self.console.print("[dim]Imported OAuth credentials are kept for account login/status only, not used as an OpenAI API key.[/dim]") + if status.email: + self.console.print(f" Email: {status.email}") + self.console.print() + self.console.print("[dim]You can now manage the linked OpenAI account from PantheonOS.[/dim]") + self.console.print() + else: + self.console.print("[red]✗ Failed to import Codex CLI authentication[/red]") + self.console.print("[dim]Make sure you have run 'codex login' first.[/dim]") + self.console.print() + except Exception as e: + self.console.print(f"[red]✗ Import error: {e}[/red]") + self.console.print() + + elif subcommand == "prefs": + self.console.print() + self.console.print("[bold]OpenAI Authentication Preferences[/bold]") + self.console.print() + state = get_openai_auth_summary_state() + self.console.print(f" Mode: {state['mode']}") + self.console.print(f" API Key Enabled: {state['enable_api_key']}") + self.console.print(f" OAuth Enabled: {state['enable_oauth']}") + self.console.print(f" API Key Present: {state['api_key_present']}") + self.console.print(f" OAuth Authenticated: {state['oauth_authenticated']}") + self.console.print(f" Effective API Key Routing: {state['effective_api_key_enabled']}") + self.console.print(f" Effective OAuth Routing: {state['effective_oauth_enabled']}") + self.console.print() + self.console.print("[dim]Modes: auto, prefer_api_key, prefer_oauth, api_key_only, oauth_only[/dim]") + self.console.print() + + elif subcommand == "mode": + mode = parts[1] if len(parts) > 1 else "" + if mode not in VALID_OPENAI_AUTH_MODES: + self.console.print("[yellow]Usage: /oauth mode [/yellow]") + self.console.print() + return + + if _save_openai_auth_settings_to_settings({"mode": mode}): + get_settings().reload() + self.console.print(f"[green]✓ OpenAI auth mode set to {mode}[/green]") + else: + self.console.print("[red]✗ Failed to update auth mode[/red]") + self.console.print() + + elif subcommand in {"enable", "disable"}: + target = parts[1] if len(parts) > 1 else "" + enabled = subcommand == "enable" + key_map = { + "api-key": "enable_api_key", + "apikey": "enable_api_key", + "api_key": "enable_api_key", + "oauth": "enable_oauth", + } + setting_key = key_map.get(target) + if not setting_key: + self.console.print("[yellow]Usage: /oauth enable or /oauth disable [/yellow]") + self.console.print() + return + + if _save_openai_auth_settings_to_settings({setting_key: enabled}): + get_settings().reload() + verb = "enabled" if enabled else "disabled" + self.console.print(f"[green]✓ {target} {verb} for OpenAI auth routing[/green]") + else: + self.console.print("[red]✗ Failed to update auth preference[/red]") + self.console.print() + + else: + self.console.print(f"[red]Unknown subcommand: {subcommand}[/red]") + self.console.print("[dim]Use /oauth login, /oauth status, /oauth logout, /oauth import-codex, /oauth prefs, /oauth mode, /oauth enable, or /oauth disable[/dim]") + self.console.print() + async def _handle_model_command(self, args: str): """Handle /model command - list or set model.""" if not args: diff --git a/pantheon/repl/setup_wizard.py b/pantheon/repl/setup_wizard.py index b979c6a6e..811e92643 100644 --- a/pantheon/repl/setup_wizard.py +++ b/pantheon/repl/setup_wizard.py @@ -17,8 +17,11 @@ from dataclasses import dataclass from typing import Optional +from pantheon.auth.oauth_manager import get_oauth_manager from pantheon.utils.model_selector import PROVIDER_API_KEYS, CUSTOM_ENDPOINT_ENVS, CustomEndpointConfig from pantheon.utils.log import logger +from pantheon.settings import load_jsonc +from pantheon.auth.openai_auth_strategy import summarize_openai_auth_state # ============ Data Classes for Better Readability ============ @@ -36,6 +39,7 @@ class ProviderMenuEntry: # Providers shown in the wizard/keys menu PROVIDER_MENU = [ ProviderMenuEntry("openai", "OpenAI", "OPENAI_API_KEY"), + ProviderMenuEntry("openai_oauth", "OpenAI (OAuth)", None), # OAuth doesn't require API key ProviderMenuEntry("anthropic", "Anthropic", "ANTHROPIC_API_KEY"), ProviderMenuEntry("gemini", "Google Gemini", "GEMINI_API_KEY"), ProviderMenuEntry("google", "Google AI", "GOOGLE_API_KEY"), @@ -65,13 +69,45 @@ class ProviderMenuEntry: for config in CUSTOM_ENDPOINT_ENVS.values() ] + +def _get_openai_oauth_status(): + try: + manager = get_oauth_manager() + provider = manager.get_provider("openai") + if hasattr(provider, "peek_status"): + return provider.peek_status() + return manager.get_status("openai") + except Exception: + return None + + +def get_openai_auth_summary_state() -> dict: + oauth_status = _get_openai_oauth_status() + return summarize_openai_auth_state( + api_key_present=bool(os.environ.get("OPENAI_API_KEY")), + oauth_authenticated=bool(oauth_status and oauth_status.authenticated), + ) + + +def _render_openai_auth_summary(console, title: str = "OpenAI Auth Status", state: dict | None = None): + state = state or get_openai_auth_summary_state() + + console.print() + console.print(f"[bold]{title}[/bold]") + console.print(f" API Key: {'configured' if state['api_key_present'] else 'not configured'}") + console.print(f" OAuth: {'authenticated' if state['oauth_authenticated'] else 'not authenticated'}") + console.print(f" Mode: {state['mode']}") + console.print( + f" Routing: api_key={'on' if state['effective_api_key_enabled'] else 'off'}, " + f"oauth={'on' if state['effective_oauth_enabled'] else 'off'}" + ) + console.print() def check_and_run_setup(): - """Check if any LLM provider API keys are set; launch wizard if none found. + """Check if any callable LLM provider credentials are set; launch wizard if none found. Called at startup before the event loop starts (sync context). Also checks for universal LLM_API_KEY (custom API endpoint) and custom endpoint keys (CUSTOM_*_API_KEY). - Skips the wizard if: - Any API key is already configured - SKIP_SETUP_WIZARD environment variable is set @@ -90,6 +126,20 @@ def check_and_run_setup(): if os.environ.get(config.api_key_env, ""): return + # Check OAuth providers + try: + oauth_manager = get_oauth_manager() + for provider_name in oauth_manager.list_providers(): + provider = oauth_manager.get_provider(provider_name) + if hasattr(provider, "peek_status"): + status = provider.peek_status() + else: + status = oauth_manager.get_status(provider_name) + if status and status.authenticated: + return + except Exception: + pass + # Check legacy universal LLM_API_KEY (with deprecation warning) if os.environ.get("LLM_API_KEY", ""): if os.environ.get("LLM_API_BASE", ""): @@ -100,7 +150,7 @@ def check_and_run_setup(): ) return - # No API keys found - launch wizard + # No API keys or OAuth found - launch wizard run_setup_wizard() @@ -127,6 +177,7 @@ def run_setup_wizard(standalone: bool = False): border_style="cyan", ) ) + _render_openai_auth_summary(console, "Current OpenAI Auth Status") configured_any = False @@ -144,8 +195,12 @@ def run_setup_wizard(standalone: bool = False): # Show provider menu console.print("\nStandard Providers:") for i, entry in enumerate(PROVIDER_MENU, 1): - already_set = " [green](configured)[/green]" if os.environ.get(entry.env_var, "") else "" - console.print(f" [cyan][{i}][/cyan] {entry.display_name:<20} ({entry.env_var}){already_set}") + # Handle OAuth providers which don't have env_var + if entry.env_var is None: + already_set = "" + else: + already_set = " [green](configured)[/green]" if os.environ.get(entry.env_var, "") else "" + console.print(f" [cyan][{i}][/cyan] {entry.display_name:<20} ({entry.env_var or 'OAuth'}){already_set}") console.print() console.print("[dim] Prefix with 'd' to delete, e.g. d0, d1,d3[/dim]") console.print() @@ -202,8 +257,19 @@ def run_setup_wizard(standalone: bool = False): for idx in delete_standard_indices: entry = PROVIDER_MENU[idx] - _remove_key_from_env_file(entry.env_var) - console.print(f"[green]\u2713 {entry.display_name} ({entry.env_var}) removed[/green]") + + # Special handling for OAuth providers + if entry.provider_key == "openai_oauth": + try: + oauth_manager = get_oauth_manager() + oauth_manager.logout("openai") + console.print(f"[green]✓ {entry.display_name} credentials cleared[/green]") + except Exception as e: + logger.warning(f"Failed to clear OAuth credentials: {e}") + console.print(f"[yellow]Failed to clear {entry.display_name}: {e}[/yellow]") + else: + _remove_key_from_env_file(entry.env_var) + console.print(f"[green]\u2713 {entry.display_name} ({entry.env_var}) removed[/green]") if (delete_legacy_custom or delete_custom_indices or delete_standard_indices) and not standard_indices and not custom_indices and not has_legacy_custom: console.print() @@ -303,6 +369,38 @@ def run_setup_wizard(standalone: bool = False): # Collect API keys for selected standard providers for idx in standard_indices: entry = PROVIDER_MENU[idx] + + # Special handling for OAuth providers (no API key needed) + if entry.provider_key == "openai_oauth": + console.print(f"\n[bold]Configure {entry.display_name}[/bold]") + console.print("[dim]A browser window will open so you can authenticate with OpenAI.[/dim]") + console.print("[dim]This enables Codex OAuth transport. Standard OpenAI API model calls still require an API key or compatible endpoint.[/dim]") + + try: + oauth_manager = get_oauth_manager() + success = oauth_manager.login("openai") + + if success: + status = oauth_manager.get_status("openai") + console.print("[green]✓ OpenAI OAuth login successful[/green]") + if status.email: + console.print(f" Email: {status.email}") + if status.organization_id: + console.print(f" Organization: {status.organization_id}") + if status.project_id: + console.print(f" Project: {status.project_id}") + configured_any = True + else: + console.print("[red]✗ OpenAI OAuth login failed[/red]") + console.print("[dim]You can retry later with '/oauth login openai' in the REPL.[/dim]") + except (EOFError, KeyboardInterrupt): + console.print("\n[yellow]OAuth login cancelled.[/yellow]") + except Exception as e: + logger.warning(f"OpenAI OAuth login from setup wizard failed: {e}") + console.print(f"[red]✗ OpenAI OAuth login error: {e}[/red]") + console.print("[dim]You can retry later with '/oauth login openai' in the REPL.[/dim]") + continue + console.print(f"\n[bold]Enter API key for {entry.display_name}[/bold]") try: api_key = pt_prompt(f"{entry.env_var}: ", is_password=True) @@ -330,7 +428,9 @@ def run_setup_wizard(standalone: bool = False): if configured_any: env_path = Path.home() / ".pantheon" / ".env" - console.print(f"\n[green]\u2713 API keys saved to {env_path}[/green]") + console.print(f"\n[green]\u2713 Provider credentials updated[/green]") + console.print(f"[dim]Environment file: {env_path}[/dim]") + _render_openai_auth_summary(console, "Final OpenAI Auth Status") if not standalone: console.print(" Starting Pantheon...\n") else: @@ -492,3 +592,39 @@ def _remove_custom_model_from_settings(provider_key: str): break except Exception as e: logger.warning(f"Failed to remove custom model from settings.json: {e}") + + +def _ensure_user_settings_file() -> Path | None: + settings_path = Path.home() / ".pantheon" / "settings.json" + if settings_path.exists(): + return settings_path + + try: + settings_path.parent.mkdir(parents=True, exist_ok=True) + template = Path(__file__).parent.parent / "factory" / "templates" / "settings.json" + if template.exists(): + shutil.copy(template, settings_path) + logger.debug(f"Created {settings_path} from factory template") + return settings_path + except Exception as e: + logger.warning(f"Failed to create user settings.json: {e}") + return None + + +def _save_openai_auth_settings_to_settings(updates: dict): + """Persist auth.openai preferences to ~/.pantheon/settings.json.""" + settings_path = _ensure_user_settings_file() + if settings_path is None: + return False + + try: + data = load_jsonc(settings_path) + auth = data.setdefault("auth", {}) + openai = auth.setdefault("openai", {}) + openai.update(updates) + settings_path.write_text(json.dumps(data, indent=4), encoding="utf-8") + logger.debug(f"Updated auth.openai settings in {settings_path}") + return True + except Exception as e: + logger.warning(f"Failed to update auth.openai settings: {e}") + return False diff --git a/pantheon/toolsets/knowledge/knowledge_manager.py b/pantheon/toolsets/knowledge/knowledge_manager.py index 509b0bcaa..2a8625499 100644 --- a/pantheon/toolsets/knowledge/knowledge_manager.py +++ b/pantheon/toolsets/knowledge/knowledge_manager.py @@ -88,12 +88,18 @@ def _create_llm(): from llama_index.llms.openai import OpenAI from pantheon.settings import get_settings from pantheon.utils.llm_providers import get_litellm_proxy_kwargs + from pantheon.auth.oauth_manager import get_oauth_token settings = get_settings() + # Prefer OAuth token, fall back to API key + api_key = get_oauth_token("openai", refresh_if_needed=True) + if not api_key: + api_key = settings.get_api_key("OPENAI_API_KEY") + llm_kwargs = { "model": "gpt-4o-mini", "temperature": 0.1, - "api_key": settings.get_api_key("OPENAI_API_KEY"), + "api_key": api_key, } api_base = settings.get_api_key("OPENAI_API_BASE") if api_base: diff --git a/pantheon/utils/llm.py b/pantheon/utils/llm.py index 2d9c72c4f..9a22be907 100644 --- a/pantheon/utils/llm.py +++ b/pantheon/utils/llm.py @@ -6,9 +6,58 @@ from copy import deepcopy from typing import Any, Callable +from pantheon.auth.openai_auth_strategy import ( + is_api_key_auth_enabled, + is_oauth_auth_enabled, +) from .log import logger from .misc import run_func + +def _get_openai_api_key() -> str | None: + """Get OpenAI API key from environment. + + Returns: + API key string, or None if not available + """ + import os + if not is_api_key_auth_enabled(): + return None + return os.environ.get("OPENAI_API_KEY") + + +def _get_codex_oauth_client_kwargs() -> dict[str, Any] | None: + """Return dedicated client kwargs for Codex OAuth transport when available.""" + if not is_oauth_auth_enabled(): + return None + try: + from pantheon.auth.openai_provider import get_openai_oauth_provider + + provider = get_openai_oauth_provider() + context = provider.build_codex_auth_context( + refresh_if_needed=True, + import_codex_if_missing=True, + ) + if not context: + return None + + default_headers: dict[str, str] = {} + if context.get("account_id"): + default_headers["ChatGPT-Account-Id"] = str(context["account_id"]) + if context.get("organization_id"): + default_headers["OpenAI-Organization"] = str(context["organization_id"]) + + client_kwargs: dict[str, Any] = { + "base_url": str(context["base_url"]), + "api_key": str(context["access_token"]), + } + if default_headers: + client_kwargs["default_headers"] = default_headers + return client_kwargs + except Exception as exc: + logger.debug(f"[CODEX_OAUTH] Failed to build Codex OAuth client config: {exc}") + return None + _PATTERN_BASE64_DATA_URI = re.compile( r"data:image/([a-zA-Z0-9+-]+);base64,([A-Za-z0-9+/=]+)" ) @@ -45,11 +94,13 @@ async def acompletion_openai( ): from openai import NOT_GIVEN, APIConnectionError, AsyncOpenAI + api_key = _get_openai_api_key() + # Create client with custom base_url if provided if base_url: - client = AsyncOpenAI(base_url=base_url) + client = AsyncOpenAI(base_url=base_url, api_key=api_key) else: - client = AsyncOpenAI() + client = AsyncOpenAI(api_key=api_key) chunks = [] _tools = tools or NOT_GIVEN _pcall = (tools is not None) or NOT_GIVEN @@ -234,6 +285,7 @@ async def acompletion_responses( base_url: str | None = None, model_params: dict | None = None, num_retries: int = 3, + codex_oauth_transport: bool = False, ) -> dict: """Call OpenAI Responses API with streaming. @@ -245,15 +297,24 @@ async def acompletion_responses( # ========== Build client ========== proxy_kwargs = get_litellm_proxy_kwargs() + api_key = _get_openai_api_key() + codex_oauth_kwargs = _get_codex_oauth_client_kwargs() if codex_oauth_transport else None + if proxy_kwargs: client = AsyncOpenAI( base_url=proxy_kwargs["api_base"], api_key=proxy_kwargs["api_key"] ) + elif codex_oauth_kwargs: + logger.info( + f"[RESPONSES_API] Using Codex OAuth transport | model={model} | " + f"base_url={codex_oauth_kwargs['base_url']}" + ) + client = AsyncOpenAI(**codex_oauth_kwargs) elif base_url: - client = AsyncOpenAI(base_url=base_url) + client = AsyncOpenAI(base_url=base_url, api_key=api_key) else: - client = AsyncOpenAI() + client = AsyncOpenAI(api_key=api_key) # ========== Convert inputs ========== instructions, input_items = _convert_messages_to_responses_input(messages) @@ -266,8 +327,12 @@ async def acompletion_responses( "input": input_items, "stream": True, } + if codex_oauth_kwargs: + kwargs["store"] = False if instructions is not None: kwargs["instructions"] = instructions + elif codex_oauth_kwargs: + kwargs["instructions"] = "You are Codex." if converted_tools is not None: kwargs["tools"] = converted_tools if response_format is not None: diff --git a/pantheon/utils/llm_providers.py b/pantheon/utils/llm_providers.py index d9272cd40..64a88460b 100644 --- a/pantheon/utils/llm_providers.py +++ b/pantheon/utils/llm_providers.py @@ -13,6 +13,11 @@ from typing import Any, Callable, Optional, NamedTuple from dataclasses import dataclass +from pantheon.auth.openai_auth_strategy import ( + get_openai_auth_settings, + is_api_key_auth_enabled, + should_use_codex_oauth_transport, +) from .misc import run_func from .log import logger @@ -105,9 +110,12 @@ def detect_provider(model: str, force_litellm: bool) -> ProviderConfig: compat_base, compat_key_env = OPENAI_COMPATIBLE_PROVIDERS[provider_lower] base_url = os.environ.get(f"{provider_lower.upper()}_API_BASE", compat_base) api_key = os.environ.get(compat_key_env, "") - # Check if it's explicitly openai provider + # Check if it's explicitly openai/codex provider elif provider_lower == "openai": provider_type = ProviderType.OPENAI + elif provider_lower == "codex": + provider_type = ProviderType.OPENAI + model_name = model else: # All other prefixed models go through LiteLLM (zhipu, anthropic, etc.) provider_type = ProviderType.LITELLM @@ -136,7 +144,10 @@ def is_responses_api_model(config: ProviderConfig) -> bool: """ return ( config.provider_type == ProviderType.OPENAI - and "codex" in config.model_name.lower() + and ( + "codex" in config.model_name.lower() + or config.model_name.lower().startswith("codex/") + ) ) @@ -201,6 +212,9 @@ def get_api_key_for_provider(provider: ProviderType) -> Optional[str]: settings = get_settings() provider_lower = provider.value.lower() + if provider == ProviderType.OPENAI and not is_api_key_auth_enabled(): + return None + # 1. Check custom endpoint key first custom_key = f"custom_{provider_lower}" if custom_key in CUSTOM_ENDPOINT_ENVS: @@ -494,8 +508,19 @@ async def call_llm_provider( from .llm import acompletion_responses model_name = config.model_name + use_codex_oauth_transport = should_use_codex_oauth_transport(config.model_name) + + if config.model_name.lower().startswith("codex/") and not use_codex_oauth_transport: + prefs = get_openai_auth_settings() + raise RuntimeError( + "Codex OAuth transport is disabled by auth.openai settings " + f"(mode={prefs.mode}, enable_oauth={prefs.enable_oauth})." + ) + if model_name.startswith("openai/"): model_name = model_name.split("/", 1)[1] + elif model_name.startswith("codex/"): + model_name = model_name.split("/", 1)[1] logger.debug( f"[CALL_LLM_PROVIDER] Using Responses API for model={model_name}" @@ -509,6 +534,7 @@ async def call_llm_provider( process_chunk=process_chunk, base_url=config.base_url, model_params=model_params, + codex_oauth_transport=use_codex_oauth_transport, ) if config.provider_type == ProviderType.OPENAI: diff --git a/pantheon/utils/model_selector.py b/pantheon/utils/model_selector.py index 817ddc7d1..5db0a13b4 100644 --- a/pantheon/utils/model_selector.py +++ b/pantheon/utils/model_selector.py @@ -17,6 +17,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING +from pantheon.auth.openai_auth_strategy import should_treat_openai_api_key_as_available from .log import logger if TYPE_CHECKING: @@ -65,7 +66,7 @@ class CustomEndpointConfig: # Quality levels map to MODEL LISTS (not single models) for fallback chains # Models within each level are ordered by preference DEFAULT_PROVIDER_MODELS = { - # OpenAI: GPT-5.4 series + # OpenAI: GPT-4o series # https://platform.openai.com/docs/models "openai": { "high": ["openai/gpt-5.4-pro", "openai/gpt-5.4", "openai/gpt-5.2-pro", "openai/gpt-5.2"], @@ -157,7 +158,7 @@ class CustomEndpointConfig: QUALITY_TAGS = {"high", "normal", "low"} # Ultimate fallback model when nothing else works (must be concrete model, not tag) -ULTIMATE_FALLBACK = "openai/gpt-5.4" +ULTIMATE_FALLBACK = "openai/gpt-4o-mini" # Recommended fallback tag for general use FALLBACK_TAG = "low" @@ -216,8 +217,14 @@ def __init__(self, settings: "Settings"): self._detected_provider: str | None = None self._available_providers: set[str] | None = None + def _settings_get(self, key: str, default=None): + """Read a setting when a Settings object is available.""" + if self.settings is None: + return default + return self.settings.get(key, default) + def _get_available_providers(self) -> set[str]: - """Get set of providers with valid API keys (cached).""" + """Get set of providers with valid API credentials for model calls (cached).""" if self._available_providers is not None: return self._available_providers @@ -226,6 +233,8 @@ def _get_available_providers(self) -> set[str]: self._available_providers = set() for provider, env_key in PROVIDER_API_KEYS.items(): + if provider == "openai" and not should_treat_openai_api_key_as_available(): + continue api_key_value = os.environ.get(env_key, "") if api_key_value: self._available_providers.add(provider) @@ -238,7 +247,11 @@ def _get_available_providers(self) -> set[str]: # Universal proxy: LLM_API_KEY makes openai provider available # (most third-party proxies are OpenAI-compatible) # Note: LLM_API_BASE is deprecated, warn user to use custom endpoints instead - if not self._available_providers and os.environ.get("LLM_API_KEY", ""): + if ( + not self._available_providers + and should_treat_openai_api_key_as_available() + and os.environ.get("LLM_API_KEY", "") + ): if os.environ.get("LLM_API_BASE", ""): logger.warning( "LLM_API_BASE is deprecated. Consider using CUSTOM_OPENAI_API_BASE or " @@ -281,7 +294,7 @@ def detect_available_provider(self) -> str | None: return provider_key # 2. Priority: user config > code defaults - priority = self.settings.get( + priority = self._settings_get( "models.provider_priority", DEFAULT_PROVIDER_PRIORITY ) @@ -321,7 +334,7 @@ def _get_provider_models(self, provider: str) -> dict[str, list[str]]: return {} # Try user configuration first - user_config = self.settings.get(f"models.provider_models.{provider}", {}) + user_config = self._settings_get(f"models.provider_models.{provider}", {}) # Get code defaults default_config = DEFAULT_PROVIDER_MODELS.get(provider, {}) @@ -462,7 +475,7 @@ def resolve_model(self, tag: str) -> list[str]: if provider in CUSTOM_ENDPOINT_ENVS: # Find next available non-custom provider available = self._get_available_providers() - priority = self.settings.get("models.provider_priority", DEFAULT_PROVIDER_PRIORITY) + priority = self._settings_get("models.provider_priority", DEFAULT_PROVIDER_PRIORITY) fallback_found = False for fallback_provider in priority: if fallback_provider in available and fallback_provider not in CUSTOM_ENDPOINT_ENVS: @@ -602,7 +615,7 @@ def resolve_image_gen_model(self, quality: str = "normal") -> list[str]: for provider in priority: if provider in available: - user_config = self.settings.get(f"image_gen_models.{provider}", {}) + user_config = self._settings_get(f"image_gen_models.{provider}", {}) provider_models = user_config or DEFAULT_IMAGE_GEN_MODELS.get(provider, {}) models = provider_models.get(quality, []) if models: @@ -621,7 +634,7 @@ def get_provider_info(self) -> dict: "detected_provider": self._detected_provider or self.detect_available_provider(), "available_providers": list(self._get_available_providers()), - "priority": self.settings.get( + "priority": self._settings_get( "models.provider_priority", DEFAULT_PROVIDER_PRIORITY ), } @@ -637,7 +650,7 @@ def list_available_models(self) -> dict: "available_providers": ["openai", "anthropic"], "current_provider": "openai", "models_by_provider": { - "openai": ["openai/gpt-5.4", "openai/gpt-5.2", ...], + "openai": ["openai/gpt-4o", "openai/gpt-4o-mini", ...], "anthropic": ["anthropic/claude-opus-4-5-20251101", ...] }, "supported_tags": ["high", "normal", "low", "vision", ...] diff --git a/pyproject.toml b/pyproject.toml index 2b4a4bbea..89ee316f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "nats-py>=2.10.0", "nkeys", "PyNaCl>=1.5.0", + "PyJWT[crypto]>=2.10.1", "httpx>=0.28.1", "diskcache", "python-frontmatter>=1.1.0", diff --git a/tests/test_backward_compatibility.py b/tests/test_backward_compatibility.py new file mode 100644 index 000000000..1859b71c3 --- /dev/null +++ b/tests/test_backward_compatibility.py @@ -0,0 +1,325 @@ +""" +Backward Compatibility Tests for API Key Authentication + +Tests that OAuth support does NOT break existing API Key authentication. +Focuses on key integration points: ModelSelector, Setup Wizard, and REPL. +""" + +import os +import unittest +from pathlib import Path +from unittest.mock import Mock, patch + +import pytest + + +class TestModelSelectorBackwardCompatibility(unittest.TestCase): + """Test ModelSelector still works with API Key authentication.""" + + def setUp(self): + """Set up test environment.""" + self.original_api_key = os.environ.get("OPENAI_API_KEY") + + def tearDown(self): + """Restore original environment.""" + if self.original_api_key: + os.environ["OPENAI_API_KEY"] = self.original_api_key + else: + os.environ.pop("OPENAI_API_KEY", None) + + def test_api_key_detection(self): + """Test that ModelSelector detects API Key.""" + from pantheon.utils.model_selector import ModelSelector + + os.environ["OPENAI_API_KEY"] = "sk-test123" + + selector = ModelSelector(None) + provider = selector.detect_available_provider() + + # Should detect openai provider via API key + assert provider == "openai" + + def test_model_resolution_with_api_key(self): + """Test that models can be resolved with API key.""" + from pantheon.utils.model_selector import ModelSelector + + os.environ["OPENAI_API_KEY"] = "sk-test123" + + selector = ModelSelector(None) + models = selector.resolve_model("normal") + + # Should return list of models + assert isinstance(models, list) + + def test_api_key_still_has_public_api(self): + """Test that ModelSelector public API is unchanged.""" + from pantheon.utils.model_selector import ModelSelector + + selector = ModelSelector(None) + + # Check public methods exist + assert hasattr(selector, "detect_available_provider") + assert hasattr(selector, "resolve_model") + assert hasattr(selector, "get_provider_info") + assert hasattr(selector, "list_available_models") + + def test_no_oauth_doesnt_break_selector(self): + """Test that missing OAuth doesn't break ModelSelector.""" + from pantheon.utils.model_selector import ModelSelector + + os.environ["OPENAI_API_KEY"] = "sk-test123" + + with patch( + "pantheon.auth.oauth_manager.get_oauth_manager" + ) as mock_oauth: + # Simulate OAuth not available + mock_oauth.side_effect = ImportError("OAuth not configured") + + selector = ModelSelector(None) + + # Should not crash, should still work with API key + provider = selector.detect_available_provider() + assert provider == "openai" + + +class TestSetupWizardBackwardCompatibility(unittest.TestCase): + """Test Setup Wizard still supports API Key authentication.""" + + def test_api_key_option_in_menu(self): + """Test that OpenAI API key option is in Setup Wizard menu.""" + from pantheon.repl.setup_wizard import PROVIDER_MENU + + api_key_entries = [ + e for e in PROVIDER_MENU if e.provider_key == "openai" + ] + + assert len(api_key_entries) == 1 + assert api_key_entries[0].display_name == "OpenAI" + + def test_api_key_env_var_in_menu(self): + """Test that API Key menu entry has correct env var.""" + from pantheon.repl.setup_wizard import PROVIDER_MENU + + api_key_entry = next( + (e for e in PROVIDER_MENU if e.provider_key == "openai"), None + ) + + assert api_key_entry is not None + assert api_key_entry.env_var == "OPENAI_API_KEY" + + def test_both_auth_methods_available(self): + """Test that both OAuth and API Key are available.""" + from pantheon.repl.setup_wizard import PROVIDER_MENU + + provider_keys = [e.provider_key for e in PROVIDER_MENU] + + assert "openai" in provider_keys, "API Key option must be present" + assert "openai_oauth" in provider_keys, "OAuth option must be present" + + def test_menu_structure_preserved(self): + """Test that menu structure is still valid.""" + from pantheon.repl.setup_wizard import PROVIDER_MENU + + # Should be a list + assert isinstance(PROVIDER_MENU, list) + + # All entries should have required properties + for entry in PROVIDER_MENU: + assert hasattr(entry, "provider_key") + assert hasattr(entry, "display_name") + + def test_existing_oauth_skips_automatic_setup_wizard(self): + """Test that authenticated OAuth counts as existing credentials for startup.""" + from pantheon.auth.oauth_manager import OAuthStatus + from pantheon.repl import setup_wizard + + mock_manager = Mock() + mock_manager.list_providers.return_value = ["openai"] + mock_manager.get_status.return_value = OAuthStatus(authenticated=True, provider="openai") + + with patch("pantheon.repl.setup_wizard.get_oauth_manager", return_value=mock_manager): + with patch("pantheon.repl.setup_wizard.run_setup_wizard") as mock_run: + with patch.dict(os.environ, {}, clear=True): + setup_wizard.check_and_run_setup() + mock_run.assert_not_called() + + +class TestREPLBackwardCompatibility(unittest.TestCase): + """Test REPL commands still work with API Key.""" + + def test_repl_package_exports_repl_symbol(self): + """Test that pantheon.repl exports Repl lazily.""" + import pantheon.repl as repl_pkg + + assert "Repl" in getattr(repl_pkg, "__all__", []) + assert hasattr(repl_pkg, "__getattr__") + + def test_oauth_command_contract_present_in_source(self): + """Test that the REPL source still defines the OAuth command handler.""" + core_path = Path(__file__).resolve().parents[1] / "pantheon" / "repl" / "core.py" + content = core_path.read_text(encoding="utf-8") + + assert "def _handle_oauth_command" in content + assert 'elif cmd_lower.startswith("/oauth")' in content + + def test_setup_wizard_import_no_longer_requires_repl_core(self): + """Test that setup_wizard import does not force pantheon.repl.core import.""" + from pantheon.repl.setup_wizard import PROVIDER_MENU + + assert isinstance(PROVIDER_MENU, list) + + +class TestAuthenticationCoexistence(unittest.TestCase): + """Test that API Key and OAuth can coexist.""" + + def setUp(self): + """Set up test environment.""" + self.original_api_key = os.environ.get("OPENAI_API_KEY") + + def tearDown(self): + """Clean up.""" + if self.original_api_key: + os.environ["OPENAI_API_KEY"] = self.original_api_key + else: + os.environ.pop("OPENAI_API_KEY", None) + + def test_api_key_with_oauth_token(self): + """Test that both can be present simultaneously.""" + from pantheon.utils.model_selector import ModelSelector + + # Set API key + os.environ["OPENAI_API_KEY"] = "sk-test123" + + with patch( + "pantheon.auth.oauth_manager.get_oauth_manager" + ) as mock_oauth: + mock_mgr = Mock() + mock_mgr.auth_path = Path("oauth_openai.json") + mock_oauth.return_value = mock_mgr + + selector = ModelSelector(None) + + # Should detect OpenAI (works with either auth method) + provider = selector.detect_available_provider() + assert provider == "openai" + + def test_api_key_preferred_when_both_present(self): + """Test API Key detection when both are available.""" + from pantheon.utils.model_selector import ModelSelector + + os.environ["OPENAI_API_KEY"] = "sk-test123" + + selector = ModelSelector(None) + + # Should detect API key (simpler to check first) + provider = selector.detect_available_provider() + assert provider == "openai" + + +class TestNoAuthenticationScenario(unittest.TestCase): + """Test system behavior without any authentication.""" + + def setUp(self): + """Set up test environment.""" + self.original_api_key = os.environ.get("OPENAI_API_KEY") + + def tearDown(self): + """Restore environment.""" + if self.original_api_key: + os.environ["OPENAI_API_KEY"] = self.original_api_key + else: + os.environ.pop("OPENAI_API_KEY", None) + + def test_setup_wizard_menu_available_without_auth(self): + """Test Setup Wizard offers options even without auth.""" + from pantheon.repl.setup_wizard import PROVIDER_MENU + + # Clear any auth + os.environ.pop("OPENAI_API_KEY", None) + + # Menu should still exist and offer options + assert len(PROVIDER_MENU) > 0 + assert any(e.provider_key == "openai" for e in PROVIDER_MENU) + + +class TestAPIKeyPriority(unittest.TestCase): + """Test that API Key check is working correctly.""" + + def setUp(self): + """Set up test environment.""" + self.original_api_key = os.environ.get("OPENAI_API_KEY") + + def tearDown(self): + """Restore environment.""" + if self.original_api_key: + os.environ["OPENAI_API_KEY"] = self.original_api_key + else: + os.environ.pop("OPENAI_API_KEY", None) + + def test_api_key_string_detection(self): + """Test that API key detection works with valid key format.""" + from pantheon.utils.model_selector import ModelSelector + + # Set a properly formatted API key + os.environ["OPENAI_API_KEY"] = "sk-proj-abcdef123456" + + selector = ModelSelector(None) + provider = selector.detect_available_provider() + + assert provider == "openai" + + def test_empty_api_key_not_detected(self): + """Test that empty API key is not detected as valid.""" + from pantheon.utils.model_selector import ModelSelector + + # Set empty API key + os.environ["OPENAI_API_KEY"] = "" + + selector = ModelSelector(None) + provider = selector.detect_available_provider() + + # Should not detect empty string as valid provider + assert provider != "openai" or provider is None + + +# Pytest-style integration tests +@pytest.mark.integration +class TestBackwardCompatibilityIntegration: + """Integration tests for backward compatibility.""" + + def test_api_key_full_flow(self): + """Test complete flow with API Key authentication.""" + from pantheon.utils.model_selector import ModelSelector + + os.environ["OPENAI_API_KEY"] = "sk-test123" + + selector = ModelSelector(None) + + # Detect provider + assert selector.detect_available_provider() == "openai" + + # Resolve models + models = selector.resolve_model("normal") + assert isinstance(models, list) + + def test_api_key_and_oauth_menu_both_present(self): + """Test that both auth options are in Setup Wizard.""" + from pantheon.repl.setup_wizard import PROVIDER_MENU + + provider_keys = [e.provider_key for e in PROVIDER_MENU] + + # Both must be present + assert "openai" in provider_keys + assert "openai_oauth" in provider_keys + + # Count should be 2 for OpenAI options + openai_count = sum( + 1 + for e in PROVIDER_MENU + if e.provider_key in ["openai", "openai_oauth"] + ) + assert openai_count == 2 + + +if __name__ == "__main__": + unittest.main(argv=[""], exit=False, verbosity=2) diff --git a/tests/test_model_selector.py b/tests/test_model_selector.py index 7d5b83925..8fd4c58de 100644 --- a/tests/test_model_selector.py +++ b/tests/test_model_selector.py @@ -125,6 +125,15 @@ def test_fallback_to_available_not_in_priority(self, mock_settings): result = selector.detect_available_provider() assert result == "deepseek" + def test_openai_not_available_when_api_key_routing_disabled(self, mock_settings): + selector = ModelSelector(mock_settings) + selector._available_providers = None + + with patch("pantheon.utils.model_selector.should_treat_openai_api_key_as_available", return_value=False): + with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test123"}, clear=False): + available = selector._get_available_providers() + assert "openai" not in available + class TestModelResolution: """Test model tag resolution.""" diff --git a/tests/test_openai_auth_strategy.py b/tests/test_openai_auth_strategy.py new file mode 100644 index 000000000..b0fa16443 --- /dev/null +++ b/tests/test_openai_auth_strategy.py @@ -0,0 +1,72 @@ +from unittest.mock import patch + +import asyncio +import pytest + +from pantheon.auth.openai_auth_strategy import ( + get_openai_auth_settings, + is_api_key_auth_enabled, + is_oauth_auth_enabled, + should_use_codex_oauth_transport, +) + + +def _mock_settings(auth_openai: dict): + class _Settings: + def get(self, key, default=None): + return auth_openai if key == "auth.openai" else default + return _Settings() + + +def test_api_key_can_be_disabled_by_settings(): + with patch("pantheon.auth.openai_auth_strategy.get_settings", return_value=_mock_settings({ + "mode": "auto", + "enable_api_key": False, + "enable_oauth": True, + })): + prefs = get_openai_auth_settings() + assert prefs.mode == "auto" + assert is_api_key_auth_enabled() is False + assert is_oauth_auth_enabled() is True + + +def test_oauth_only_disables_api_key_routing(): + with patch("pantheon.auth.openai_auth_strategy.get_settings", return_value=_mock_settings({ + "mode": "oauth_only", + "enable_api_key": True, + "enable_oauth": True, + })): + assert is_api_key_auth_enabled() is False + assert is_oauth_auth_enabled() is True + assert should_use_codex_oauth_transport("codex/gpt-5.4") is True + + +def test_api_key_only_disables_codex_oauth_transport(): + with patch("pantheon.auth.openai_auth_strategy.get_settings", return_value=_mock_settings({ + "mode": "api_key_only", + "enable_api_key": True, + "enable_oauth": True, + })): + assert is_api_key_auth_enabled() is True + assert is_oauth_auth_enabled() is False + assert should_use_codex_oauth_transport("codex/gpt-5.4") is False + + +def test_codex_model_respects_disabled_oauth(): + from pantheon.utils.llm_providers import ProviderConfig, ProviderType, call_llm_provider + + with patch("pantheon.auth.openai_auth_strategy.get_settings", return_value=_mock_settings({ + "mode": "api_key_only", + "enable_api_key": True, + "enable_oauth": True, + })): + with pytest.raises(RuntimeError, match="Codex OAuth transport is disabled"): + asyncio.run( + call_llm_provider( + config=ProviderConfig( + provider_type=ProviderType.OPENAI, + model_name="codex/gpt-5.4", + ), + messages=[{"role": "user", "content": "hi"}], + ) + ) diff --git a/tests/test_openai_provider_security.py b/tests/test_openai_provider_security.py new file mode 100644 index 000000000..16641fca5 --- /dev/null +++ b/tests/test_openai_provider_security.py @@ -0,0 +1,31 @@ +from unittest.mock import patch + +from pantheon.auth import openai_provider + + +def test_check_origin_accepts_allowed_origin(): + handler = openai_provider._OAuthCallbackHandler.__new__(openai_provider._OAuthCallbackHandler) + handler.headers = {"Origin": "https://auth.openai.com/some/path"} + assert handler._check_origin() is True + + +def test_check_origin_rejects_untrusted_origin(): + handler = openai_provider._OAuthCallbackHandler.__new__(openai_provider._OAuthCallbackHandler) + handler.headers = {"Origin": "https://evil.example.com"} + assert handler._check_origin() is False + + +def test_decode_jwt_payload_does_not_fallback_for_sensitive_claims(): + with patch.object(openai_provider, "_decode_jwt_payload_verified", return_value={}): + with patch.object(openai_provider, "_decode_jwt_payload_unverified", return_value={"email": "forged@example.com"}): + payload = openai_provider._decode_jwt_payload("fake-token") + assert payload == {} + assert openai_provider._extract_email("fake-token") == "" + + +def test_decode_jwt_payload_allows_unverified_fallback_for_exp_only(): + with patch.object(openai_provider, "_decode_jwt_payload_verified", return_value={}): + with patch.object(openai_provider, "_decode_jwt_payload_unverified", return_value={"exp": 2000000000}): + payload = openai_provider._decode_jwt_payload("fake-token", allow_unverified_fallback=True) + assert payload == {"exp": 2000000000} + assert openai_provider._extract_token_exp("fake-token") == 2000000000.0