Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _get_headers(self) -> dict[str, str]:

return {
"accept": "application/json",
"authorization": f"Bearer {self.api_token}",
"authorization": f"Bearer {token}",
"content-type": "application/json",
}

Expand Down
13 changes: 13 additions & 0 deletions core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ class Settings:
transport: str = field(default_factory=lambda: os.getenv("MCP_TRANSPORT", "stdio"))
log_level: str = field(default_factory=lambda: os.getenv("LOG_LEVEL", "INFO"))

# OAuth / Remote Auth Configuration
server_url: str = field(default_factory=lambda: os.getenv("MCP_SERVER_URL", ""))
auth_base_url: str = field(
default_factory=lambda: os.getenv(
"ACEDATACLOUD_AUTH_BASE_URL", "https://auth.acedata.cloud"
)
)
platform_base_url: str = field(
default_factory=lambda: os.getenv(
"ACEDATACLOUD_PLATFORM_BASE_URL", "https://platform.acedata.cloud"
)
)

def validate(self) -> None:
"""Validate required settings."""
if not self.api_token:
Expand Down
299 changes: 299 additions & 0 deletions core/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,299 @@
"""OAuth 2.1 provider for AceDataCloud MCP servers.

Implements the MCP SDK's OAuthAuthorizationServerProvider interface,
delegating user authentication to AceDataCloud's AuthBackend.

Flow:
1. Claude.ai redirects user to /authorize
2. MCP server redirects to auth.acedata.cloud login
3. User logs in, auth redirects back to /oauth/callback with code
4. MCP server exchanges code for JWT, fetches user's API credential
5. Issues the credential token as the OAuth access_token
6. Claude uses this token for all subsequent MCP requests
"""

import secrets
import time
from urllib.parse import quote, urlencode

import httpx
from loguru import logger
from mcp.server.auth.provider import (
AccessToken,
AuthorizationCode,
AuthorizationParams,
OAuthClientInformationFull,
OAuthToken,
RefreshToken,
)
from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse

from core.client import set_request_api_token
from core.config import settings


class AceDataCloudOAuthProvider:
"""OAuth provider that delegates authentication to AceDataCloud platform.

In-memory storage is used for auth state (suitable for single-replica K8s deployment).
"""

def __init__(self) -> None:
self._clients: dict[str, OAuthClientInformationFull] = {}
self._auth_codes: dict[
str, tuple[AuthorizationCode, str]
] = {} # code → (AuthCode, api_token)
self._access_tokens: dict[str, AccessToken] = {}
self._refresh_tokens: dict[str, RefreshToken] = {}
self._pending_auth: dict[str, dict] = {} # mcp_state → {client_id, params}

async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
return self._clients.get(client_id)

async def register_client(self, client_info: OAuthClientInformationFull) -> None:
client_id = client_info.client_id
assert client_id is not None
self._clients[client_id] = client_info
logger.info(f"Registered OAuth client: {client_id}")

async def authorize(
self, client: OAuthClientInformationFull, params: AuthorizationParams
) -> str:
"""Redirect user to AceDataCloud login page."""
# Generate state key for tracking this auth flow
mcp_state = secrets.token_urlsafe(32)
self._pending_auth[mcp_state] = {
"client_id": client.client_id,
"redirect_uri": str(params.redirect_uri),
"state": params.state,
"code_challenge": params.code_challenge,
"redirect_uri_provided_explicitly": params.redirect_uri_provided_explicitly,
"scopes": params.scopes,
"resource": params.resource,
}

# Build callback URL with mcp_state
callback_url = f"{settings.server_url}/oauth/callback?mcp_state={mcp_state}"

# Redirect to AceDataCloud auth login
auth_login_url = (
f"{settings.auth_base_url}/auth/login?redirect={quote(callback_url, safe='')}"
)
logger.info(f"OAuth authorize: redirecting to AceDataCloud auth (mcp_state={mcp_state})")
return auth_login_url

async def handle_callback(self, request: Request) -> RedirectResponse | JSONResponse:
"""Handle the callback from AceDataCloud auth after user login.

This is called as a Starlette route handler, not part of the SDK interface.
"""
mcp_state = request.query_params.get("mcp_state")
adc_code = request.query_params.get("code")

if not mcp_state or not adc_code:
return JSONResponse({"error": "Missing mcp_state or code parameter"}, status_code=400)

pending = self._pending_auth.pop(mcp_state, None)
if not pending:
return JSONResponse({"error": "Invalid or expired mcp_state"}, status_code=400)

try:
# Exchange AceDataCloud code for JWT
jwt_token = await self._exchange_code_for_jwt(adc_code)
if not jwt_token:
return JSONResponse(
{"error": "Failed to exchange authorization code"}, status_code=502
)

# Fetch user's API credential token from PlatformBackend
api_token = await self._get_user_credential(jwt_token)
if not api_token:
return JSONResponse(
{
"error": "No API credential found. Please create an API key at "
"https://platform.acedata.cloud first."
},
status_code=403,
)

# Create MCP authorization code
auth_code_str = secrets.token_urlsafe(48)
auth_code = AuthorizationCode(
code=auth_code_str,
scopes=pending.get("scopes") or [],
expires_at=time.time() + 600, # 10 minutes
client_id=pending["client_id"],
code_challenge=pending["code_challenge"],
redirect_uri=pending["redirect_uri"],
redirect_uri_provided_explicitly=pending["redirect_uri_provided_explicitly"],
resource=pending.get("resource"),
)
self._auth_codes[auth_code_str] = (auth_code, api_token)

# Redirect back to Claude with the MCP auth code
redirect_uri = pending["redirect_uri"]
params = {"code": auth_code_str}
if pending.get("state"):
params["state"] = pending["state"]

separator = "&" if "?" in redirect_uri else "?"
redirect_url = f"{redirect_uri}{separator}{urlencode(params)}"
logger.info("OAuth callback: issuing auth code, redirecting to client")
return RedirectResponse(url=redirect_url, status_code=302)

except Exception:
logger.exception("OAuth callback error")
return JSONResponse({"error": "Internal server error"}, status_code=500)

async def load_authorization_code(
self,
client: OAuthClientInformationFull, # noqa: ARG002
authorization_code: str,
) -> AuthorizationCode | None:
data = self._auth_codes.get(authorization_code)
if not data:
return None
auth_code, _ = data
if auth_code.expires_at < time.time():
self._auth_codes.pop(authorization_code, None)
return None
return auth_code

async def exchange_authorization_code(
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
) -> OAuthToken:
data = self._auth_codes.pop(authorization_code.code, None)
if not data:
raise ValueError("Authorization code not found or already used")
_, api_token = data

client_id = client.client_id or ""

# Store access token mapping
self._access_tokens[api_token] = AccessToken(
token=api_token,
client_id=client_id,
scopes=authorization_code.scopes,
expires_at=None, # API credential tokens don't expire by time
)

# Generate refresh token
refresh_token_str = secrets.token_urlsafe(48)
self._refresh_tokens[refresh_token_str] = RefreshToken(
token=refresh_token_str,
client_id=client_id,
scopes=authorization_code.scopes,
)

logger.info(f"OAuth token exchange: issued access token for client {client_id}")
return OAuthToken(
access_token=api_token,
token_type="Bearer",
refresh_token=refresh_token_str,
)

async def load_refresh_token(
self,
client: OAuthClientInformationFull, # noqa: ARG002
refresh_token: str,
) -> RefreshToken | None:
return self._refresh_tokens.get(refresh_token)

async def exchange_refresh_token(
self,
client: OAuthClientInformationFull,
refresh_token: RefreshToken,
scopes: list[str],
) -> OAuthToken:
# For refresh, we reuse the same API credential token
# Find the associated access token
self._refresh_tokens.pop(refresh_token.token, None)

# The original access_token (API credential) is still valid
# Just issue a new refresh token
client_id = client.client_id or ""
new_refresh = secrets.token_urlsafe(48)
self._refresh_tokens[new_refresh] = RefreshToken(
token=new_refresh,
client_id=client_id,
scopes=scopes or refresh_token.scopes,
)

# Find the access token for this client
for token, at in self._access_tokens.items():
if at.client_id == client.client_id:
return OAuthToken(
access_token=token,
token_type="Bearer",
refresh_token=new_refresh,
)

raise ValueError("No access token found for refresh")

async def load_access_token(self, token: str) -> AccessToken | None:
"""Validate an access token.

Accepts both OAuth-issued tokens and direct API credential tokens.
Direct tokens are accepted since the real validation happens at api.acedata.cloud.
"""
# Check OAuth-issued tokens first
if token in self._access_tokens:
access_token = self._access_tokens[token]
if access_token.expires_at and time.time() > access_token.expires_at:
self._access_tokens.pop(token, None)
return None
set_request_api_token(token)
return access_token

# Accept direct API credential tokens (for VS Code, Cursor, etc.)
set_request_api_token(token)
return AccessToken(token=token, client_id="direct", scopes=[])

async def revoke_token(self, token: AccessToken | RefreshToken) -> None:
if isinstance(token, AccessToken):
self._access_tokens.pop(token.token, None)
elif isinstance(token, RefreshToken):
self._refresh_tokens.pop(token.token, None)
logger.info(f"Revoked token: {token.token[:8]}...")

# --- Internal helpers ---

async def _exchange_code_for_jwt(self, code: str) -> str | None:
"""Exchange AceDataCloud auth code for JWT via legacy token endpoint."""
try:
async with httpx.AsyncClient(timeout=30) as client:
response = await client.post(
f"{settings.auth_base_url}/oauth2/v1/token",
json={"code": code},
)
if response.status_code == 200:
data = response.json()
access_token: str | None = data.get("access_token")
return access_token
logger.error(f"Auth code exchange failed: {response.status_code} {response.text}")
except Exception:
logger.exception("Auth code exchange error")
return None

async def _get_user_credential(self, jwt_token: str) -> str | None:
"""Fetch user's first available API credential token from PlatformBackend."""
try:
async with httpx.AsyncClient(timeout=30) as client:
response = await client.get(
f"{settings.platform_base_url}/api/v1/credentials/",
headers={"Authorization": f"Bearer {jwt_token}"},
)
if response.status_code == 200:
data = response.json()
results = data.get("results", data) if isinstance(data, dict) else data
if isinstance(results, list):
for cred in results:
cred_token: str | None = cred.get("token")
if cred_token:
logger.info("Found user credential token")
return cred_token
logger.warning(f"No credentials found: {response.status_code}")
except Exception:
logger.exception("Credential fetch error")
return None
22 changes: 21 additions & 1 deletion core/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,27 @@

logger = logging.getLogger(__name__)

# Build FastMCP kwargs, enabling OAuth when MCP_SERVER_URL is configured
mcp_kwargs: dict = {"host": "0.0.0.0"}
oauth_provider = None

if settings.server_url:
from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions, RevocationOptions
from pydantic import AnyHttpUrl

from core.oauth import AceDataCloudOAuthProvider

oauth_provider = AceDataCloudOAuthProvider()
mcp_kwargs["auth_server_provider"] = oauth_provider
mcp_kwargs["auth"] = AuthSettings(
issuer_url=AnyHttpUrl(settings.server_url),
resource_server_url=AnyHttpUrl(settings.server_url),
client_registration_options=ClientRegistrationOptions(enabled=True),
revocation_options=RevocationOptions(enabled=True),
)
logger.info(f"OAuth enabled: issuer_url={settings.server_url}")

# Initialize FastMCP server
mcp = FastMCP(settings.server_name, host="0.0.0.0")
mcp = FastMCP(settings.server_name, **mcp_kwargs)

logger.info(f"Initialized MCP server: {settings.server_name}")
3 changes: 3 additions & 0 deletions deploy/production/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ spec:
ports:
- containerPort: 8000
protocol: TCP
env:
- name: MCP_SERVER_URL
value: "https://luma.mcp.acedata.cloud"
resources:
limits:
cpu: 500m
Expand Down
Loading