diff --git a/core/client.py b/core/client.py index 0714785..2bd9ae2 100644 --- a/core/client.py +++ b/core/client.py @@ -53,7 +53,7 @@ def _get_headers(self) -> dict[str, str]: return { "accept": "application/json", - "authorization": f"Bearer {self.api_token}", + "authorization": f"Bearer {token}", "content-type": "application/json", } diff --git a/core/config.py b/core/config.py index 0e81a8f..58ce992 100644 --- a/core/config.py +++ b/core/config.py @@ -36,6 +36,19 @@ class Settings: transport: str = field(default_factory=lambda: os.getenv("MCP_TRANSPORT", "stdio")) log_level: str = field(default_factory=lambda: os.getenv("LOG_LEVEL", "INFO")) + # OAuth / Remote Auth Configuration + server_url: str = field(default_factory=lambda: os.getenv("MCP_SERVER_URL", "")) + auth_base_url: str = field( + default_factory=lambda: os.getenv( + "ACEDATACLOUD_AUTH_BASE_URL", "https://auth.acedata.cloud" + ) + ) + platform_base_url: str = field( + default_factory=lambda: os.getenv( + "ACEDATACLOUD_PLATFORM_BASE_URL", "https://platform.acedata.cloud" + ) + ) + def validate(self) -> None: """Validate required settings.""" if not self.api_token: diff --git a/core/oauth.py b/core/oauth.py new file mode 100644 index 0000000..fd57869 --- /dev/null +++ b/core/oauth.py @@ -0,0 +1,299 @@ +"""OAuth 2.1 provider for AceDataCloud MCP servers. + +Implements the MCP SDK's OAuthAuthorizationServerProvider interface, +delegating user authentication to AceDataCloud's AuthBackend. + +Flow: +1. Claude.ai redirects user to /authorize +2. MCP server redirects to auth.acedata.cloud login +3. User logs in, auth redirects back to /oauth/callback with code +4. MCP server exchanges code for JWT, fetches user's API credential +5. Issues the credential token as the OAuth access_token +6. Claude uses this token for all subsequent MCP requests +""" + +import secrets +import time +from urllib.parse import quote, urlencode + +import httpx +from loguru import logger +from mcp.server.auth.provider import ( + AccessToken, + AuthorizationCode, + AuthorizationParams, + OAuthClientInformationFull, + OAuthToken, + RefreshToken, +) +from starlette.requests import Request +from starlette.responses import JSONResponse, RedirectResponse + +from core.client import set_request_api_token +from core.config import settings + + +class AceDataCloudOAuthProvider: + """OAuth provider that delegates authentication to AceDataCloud platform. + + In-memory storage is used for auth state (suitable for single-replica K8s deployment). + """ + + def __init__(self) -> None: + self._clients: dict[str, OAuthClientInformationFull] = {} + self._auth_codes: dict[ + str, tuple[AuthorizationCode, str] + ] = {} # code → (AuthCode, api_token) + self._access_tokens: dict[str, AccessToken] = {} + self._refresh_tokens: dict[str, RefreshToken] = {} + self._pending_auth: dict[str, dict] = {} # mcp_state → {client_id, params} + + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: + return self._clients.get(client_id) + + async def register_client(self, client_info: OAuthClientInformationFull) -> None: + client_id = client_info.client_id + assert client_id is not None + self._clients[client_id] = client_info + logger.info(f"Registered OAuth client: {client_id}") + + async def authorize( + self, client: OAuthClientInformationFull, params: AuthorizationParams + ) -> str: + """Redirect user to AceDataCloud login page.""" + # Generate state key for tracking this auth flow + mcp_state = secrets.token_urlsafe(32) + self._pending_auth[mcp_state] = { + "client_id": client.client_id, + "redirect_uri": str(params.redirect_uri), + "state": params.state, + "code_challenge": params.code_challenge, + "redirect_uri_provided_explicitly": params.redirect_uri_provided_explicitly, + "scopes": params.scopes, + "resource": params.resource, + } + + # Build callback URL with mcp_state + callback_url = f"{settings.server_url}/oauth/callback?mcp_state={mcp_state}" + + # Redirect to AceDataCloud auth login + auth_login_url = ( + f"{settings.auth_base_url}/auth/login?redirect={quote(callback_url, safe='')}" + ) + logger.info(f"OAuth authorize: redirecting to AceDataCloud auth (mcp_state={mcp_state})") + return auth_login_url + + async def handle_callback(self, request: Request) -> RedirectResponse | JSONResponse: + """Handle the callback from AceDataCloud auth after user login. + + This is called as a Starlette route handler, not part of the SDK interface. + """ + mcp_state = request.query_params.get("mcp_state") + adc_code = request.query_params.get("code") + + if not mcp_state or not adc_code: + return JSONResponse({"error": "Missing mcp_state or code parameter"}, status_code=400) + + pending = self._pending_auth.pop(mcp_state, None) + if not pending: + return JSONResponse({"error": "Invalid or expired mcp_state"}, status_code=400) + + try: + # Exchange AceDataCloud code for JWT + jwt_token = await self._exchange_code_for_jwt(adc_code) + if not jwt_token: + return JSONResponse( + {"error": "Failed to exchange authorization code"}, status_code=502 + ) + + # Fetch user's API credential token from PlatformBackend + api_token = await self._get_user_credential(jwt_token) + if not api_token: + return JSONResponse( + { + "error": "No API credential found. Please create an API key at " + "https://platform.acedata.cloud first." + }, + status_code=403, + ) + + # Create MCP authorization code + auth_code_str = secrets.token_urlsafe(48) + auth_code = AuthorizationCode( + code=auth_code_str, + scopes=pending.get("scopes") or [], + expires_at=time.time() + 600, # 10 minutes + client_id=pending["client_id"], + code_challenge=pending["code_challenge"], + redirect_uri=pending["redirect_uri"], + redirect_uri_provided_explicitly=pending["redirect_uri_provided_explicitly"], + resource=pending.get("resource"), + ) + self._auth_codes[auth_code_str] = (auth_code, api_token) + + # Redirect back to Claude with the MCP auth code + redirect_uri = pending["redirect_uri"] + params = {"code": auth_code_str} + if pending.get("state"): + params["state"] = pending["state"] + + separator = "&" if "?" in redirect_uri else "?" + redirect_url = f"{redirect_uri}{separator}{urlencode(params)}" + logger.info("OAuth callback: issuing auth code, redirecting to client") + return RedirectResponse(url=redirect_url, status_code=302) + + except Exception: + logger.exception("OAuth callback error") + return JSONResponse({"error": "Internal server error"}, status_code=500) + + async def load_authorization_code( + self, + client: OAuthClientInformationFull, # noqa: ARG002 + authorization_code: str, + ) -> AuthorizationCode | None: + data = self._auth_codes.get(authorization_code) + if not data: + return None + auth_code, _ = data + if auth_code.expires_at < time.time(): + self._auth_codes.pop(authorization_code, None) + return None + return auth_code + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode + ) -> OAuthToken: + data = self._auth_codes.pop(authorization_code.code, None) + if not data: + raise ValueError("Authorization code not found or already used") + _, api_token = data + + client_id = client.client_id or "" + + # Store access token mapping + self._access_tokens[api_token] = AccessToken( + token=api_token, + client_id=client_id, + scopes=authorization_code.scopes, + expires_at=None, # API credential tokens don't expire by time + ) + + # Generate refresh token + refresh_token_str = secrets.token_urlsafe(48) + self._refresh_tokens[refresh_token_str] = RefreshToken( + token=refresh_token_str, + client_id=client_id, + scopes=authorization_code.scopes, + ) + + logger.info(f"OAuth token exchange: issued access token for client {client_id}") + return OAuthToken( + access_token=api_token, + token_type="Bearer", + refresh_token=refresh_token_str, + ) + + async def load_refresh_token( + self, + client: OAuthClientInformationFull, # noqa: ARG002 + refresh_token: str, + ) -> RefreshToken | None: + return self._refresh_tokens.get(refresh_token) + + async def exchange_refresh_token( + self, + client: OAuthClientInformationFull, + refresh_token: RefreshToken, + scopes: list[str], + ) -> OAuthToken: + # For refresh, we reuse the same API credential token + # Find the associated access token + self._refresh_tokens.pop(refresh_token.token, None) + + # The original access_token (API credential) is still valid + # Just issue a new refresh token + client_id = client.client_id or "" + new_refresh = secrets.token_urlsafe(48) + self._refresh_tokens[new_refresh] = RefreshToken( + token=new_refresh, + client_id=client_id, + scopes=scopes or refresh_token.scopes, + ) + + # Find the access token for this client + for token, at in self._access_tokens.items(): + if at.client_id == client.client_id: + return OAuthToken( + access_token=token, + token_type="Bearer", + refresh_token=new_refresh, + ) + + raise ValueError("No access token found for refresh") + + async def load_access_token(self, token: str) -> AccessToken | None: + """Validate an access token. + + Accepts both OAuth-issued tokens and direct API credential tokens. + Direct tokens are accepted since the real validation happens at api.acedata.cloud. + """ + # Check OAuth-issued tokens first + if token in self._access_tokens: + access_token = self._access_tokens[token] + if access_token.expires_at and time.time() > access_token.expires_at: + self._access_tokens.pop(token, None) + return None + set_request_api_token(token) + return access_token + + # Accept direct API credential tokens (for VS Code, Cursor, etc.) + set_request_api_token(token) + return AccessToken(token=token, client_id="direct", scopes=[]) + + async def revoke_token(self, token: AccessToken | RefreshToken) -> None: + if isinstance(token, AccessToken): + self._access_tokens.pop(token.token, None) + elif isinstance(token, RefreshToken): + self._refresh_tokens.pop(token.token, None) + logger.info(f"Revoked token: {token.token[:8]}...") + + # --- Internal helpers --- + + async def _exchange_code_for_jwt(self, code: str) -> str | None: + """Exchange AceDataCloud auth code for JWT via legacy token endpoint.""" + try: + async with httpx.AsyncClient(timeout=30) as client: + response = await client.post( + f"{settings.auth_base_url}/oauth2/v1/token", + json={"code": code}, + ) + if response.status_code == 200: + data = response.json() + access_token: str | None = data.get("access_token") + return access_token + logger.error(f"Auth code exchange failed: {response.status_code} {response.text}") + except Exception: + logger.exception("Auth code exchange error") + return None + + async def _get_user_credential(self, jwt_token: str) -> str | None: + """Fetch user's first available API credential token from PlatformBackend.""" + try: + async with httpx.AsyncClient(timeout=30) as client: + response = await client.get( + f"{settings.platform_base_url}/api/v1/credentials/", + headers={"Authorization": f"Bearer {jwt_token}"}, + ) + if response.status_code == 200: + data = response.json() + results = data.get("results", data) if isinstance(data, dict) else data + if isinstance(results, list): + for cred in results: + cred_token: str | None = cred.get("token") + if cred_token: + logger.info("Found user credential token") + return cred_token + logger.warning(f"No credentials found: {response.status_code}") + except Exception: + logger.exception("Credential fetch error") + return None diff --git a/core/server.py b/core/server.py index 438b5be..4268bf3 100644 --- a/core/server.py +++ b/core/server.py @@ -18,7 +18,27 @@ logger = logging.getLogger(__name__) +# Build FastMCP kwargs, enabling OAuth when MCP_SERVER_URL is configured +mcp_kwargs: dict = {"host": "0.0.0.0"} +oauth_provider = None + +if settings.server_url: + from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions, RevocationOptions + from pydantic import AnyHttpUrl + + from core.oauth import AceDataCloudOAuthProvider + + oauth_provider = AceDataCloudOAuthProvider() + mcp_kwargs["auth_server_provider"] = oauth_provider + mcp_kwargs["auth"] = AuthSettings( + issuer_url=AnyHttpUrl(settings.server_url), + resource_server_url=AnyHttpUrl(settings.server_url), + client_registration_options=ClientRegistrationOptions(enabled=True), + revocation_options=RevocationOptions(enabled=True), + ) + logger.info(f"OAuth enabled: issuer_url={settings.server_url}") + # Initialize FastMCP server -mcp = FastMCP(settings.server_name, host="0.0.0.0") +mcp = FastMCP(settings.server_name, **mcp_kwargs) logger.info(f"Initialized MCP server: {settings.server_name}") diff --git a/deploy/production/deployment.yaml b/deploy/production/deployment.yaml index ed22d54..9b72457 100644 --- a/deploy/production/deployment.yaml +++ b/deploy/production/deployment.yaml @@ -28,6 +28,9 @@ spec: ports: - containerPort: 8000 protocol: TCP + env: + - name: MCP_SERVER_URL + value: "https://luma.mcp.acedata.cloud" resources: limits: cpu: 500m diff --git a/main.py b/main.py index 5934fa9..68e60c2 100644 --- a/main.py +++ b/main.py @@ -147,42 +147,7 @@ def main() -> None: from starlette.responses import JSONResponse from starlette.routing import Mount, Route - from core.client import set_request_api_token - - class BearerAuthMiddleware: - """ASGI middleware that extracts Bearer token and rejects - unauthenticated requests (except /health).""" - - def __init__(self, app): # type: ignore[no-untyped-def] - self.app = app - - async def __call__(self, scope, receive, send): # type: ignore[no-untyped-def] - if scope["type"] == "http": - path = scope.get("path", "") - if ( - path == "/health" - or path.startswith("/.well-known/") - or path.startswith("/mcp") - ): - await self.app(scope, receive, send) - return - headers = dict(scope.get("headers", [])) - # Allow SmitheryBot scan requests through for registry scanning - user_agent = headers.get(b"user-agent", b"").decode() - if user_agent.startswith("SmitheryBot/"): - await self.app(scope, receive, send) - return - auth = headers.get(b"authorization", b"").decode() - if auth.startswith("Bearer "): - set_request_api_token(auth[7:]) - else: - response = JSONResponse( - {"error": "Missing or invalid Authorization header"}, - status_code=401, - ) - await response(scope, receive, send) - return - await self.app(scope, receive, send) + from core.server import oauth_provider async def health(_request: Request) -> JSONResponse: return JSONResponse({"status": "ok"}) @@ -239,15 +204,19 @@ async def lifespan(_app: Starlette): # type: ignore[no-untyped-def] mcp.settings.json_response = True mcp.settings.streamable_http_path = "/mcp" - app = Starlette( - routes=[ - Route("/health", health), - Route("/.well-known/mcp/server-card.json", server_card), - Mount("/", app=mcp.streamable_http_app()), - ], - lifespan=lifespan, - ) - app.add_middleware(BearerAuthMiddleware) + # Build routes + routes: list[Route | Mount] = [ + Route("/health", health), + Route("/.well-known/mcp/server-card.json", server_card), + ] + + # Add OAuth callback route if OAuth is enabled + if oauth_provider: + routes.append(Route("/oauth/callback", oauth_provider.handle_callback)) + + routes.append(Mount("/", app=mcp.streamable_http_app())) + + app = Starlette(routes=routes, lifespan=lifespan) uvicorn.run(app, host="0.0.0.0", port=args.port) else: mcp.run(transport="stdio")