diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..406a96b --- /dev/null +++ b/.env.example @@ -0,0 +1,83 @@ +# NextMCP OAuth Integration Testing Environment Variables +# +# This file is a template for setting up OAuth credentials for integration testing. +# Copy this file to .env and fill in your actual values. +# +# Setup Instructions: +# 1. Copy this file: cp .env.example .env +# 2. Follow docs/OAUTH_TESTING_SETUP.md to get your credentials +# 3. Use examples/auth/oauth_token_helper.py to obtain tokens +# 4. Fill in the values below +# 5. Load with: export $(cat .env | xargs) +# 6. Run tests: pytest tests/test_oauth_integration.py -v -m integration +# +# IMPORTANT: Never commit .env files! They contain secrets. +# The .env file is already in .gitignore. + +# ============================================================================= +# GITHUB OAUTH CREDENTIALS +# ============================================================================= +# Get these from: https://github.com/settings/developers +# Create a new OAuth App with callback URL: http://localhost:8080/oauth/callback + +GITHUB_CLIENT_ID=your_github_client_id_here +GITHUB_CLIENT_SECRET=your_github_client_secret_here + +# Get this token by running: +# python examples/auth/oauth_token_helper.py --provider github +# Or follow manual instructions in docs/OAUTH_TESTING_SETUP.md +GITHUB_ACCESS_TOKEN=gho_your_github_access_token_here + +# ============================================================================= +# GOOGLE OAUTH CREDENTIALS +# ============================================================================= +# Get these from: https://console.cloud.google.com +# Create OAuth 2.0 credentials with callback URL: http://localhost:8080/oauth/callback +# Enable APIs: Google Drive API, Gmail API + +GOOGLE_CLIENT_ID=your_google_client_id.apps.googleusercontent.com +GOOGLE_CLIENT_SECRET=your_google_client_secret_here + +# Get these tokens by running: +# python examples/auth/oauth_token_helper.py --provider google +# Or follow manual instructions in docs/OAUTH_TESTING_SETUP.md +GOOGLE_ACCESS_TOKEN=ya29.your_google_access_token_here + +# Refresh token (optional, for token refresh tests) +# Only issued on first authorization with offline access +GOOGLE_REFRESH_TOKEN=1//your_google_refresh_token_here + +# ============================================================================= +# LOADING INSTRUCTIONS +# ============================================================================= +# Option 1: Load into current shell +# export $(cat .env | xargs) +# +# Option 2: Load with grep to filter comments +# export $(grep -v '^#' .env | xargs) +# +# Option 3: Source in your shell config +# Add to ~/.bashrc or ~/.zshrc: +# if [ -f /path/to/nextmcp/.env ]; then +# export $(cat /path/to/nextmcp/.env | grep -v '^#' | xargs) +# fi +# +# Option 4: Use python-dotenv in tests +# from dotenv import load_dotenv +# load_dotenv() + +# ============================================================================= +# VERIFICATION +# ============================================================================= +# Verify environment variables are set: +# echo $GITHUB_CLIENT_ID +# echo $GITHUB_ACCESS_TOKEN +# echo $GOOGLE_CLIENT_ID +# echo $GOOGLE_ACCESS_TOKEN +# +# Run integration tests: +# pytest tests/test_oauth_integration.py -v -m integration +# +# Run specific provider tests: +# pytest tests/test_oauth_integration.py::TestGitHubOAuthIntegration -v +# pytest tests/test_oauth_integration.py::TestGoogleOAuthIntegration -v diff --git a/.gitignore b/.gitignore index 2993231..e771d13 100644 --- a/.gitignore +++ b/.gitignore @@ -41,6 +41,13 @@ ENV/ # Environment variables .env .env.local +.env.test +.env.production +*_credentials.json +*_token.json + +# But DO track the example file +!.env.example # Testing .pytest_cache/ diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md new file mode 100644 index 0000000..38a966b --- /dev/null +++ b/docs/ARCHITECTURE.md @@ -0,0 +1,726 @@ +# NextMCP Authentication Architecture + +This document explains how NextMCP's authentication system works internally, how the components fit together, and the design decisions behind it. + +--- + +## Table of Contents + +1. [System Overview](#system-overview) +2. [Component Architecture](#component-architecture) +3. [Request Flow](#request-flow) +4. [Data Flow](#data-flow) +5. [Design Decisions](#design-decisions) +6. [Security Considerations](#security-considerations) +7. [Performance Characteristics](#performance-characteristics) + +--- + +## System Overview + +NextMCP's auth system consists of three main layers: + +``` +┌──────────────────────────────────────────────────────────────┐ +│ MCP Client/Host │ +│ (Claude Desktop, Cursor, etc.) │ +└────────────────────┬─────────────────────────────────────────┘ + │ + │ 1. Reads auth metadata + │ 2. Initiates OAuth flow + │ 3. Sends requests with tokens + │ +┌────────────────────▼─────────────────────────────────────────┐ +│ Auth Metadata Protocol │ +│ • Announces auth requirements │ +│ • Lists providers, scopes, permissions │ +│ • JSON schema for validation │ +└────────────────────┬─────────────────────────────────────────┘ + │ +┌────────────────────▼─────────────────────────────────────────┐ +│ Request Enforcement Middleware │ +│ • Validates every request │ +│ • Checks auth credentials │ +│ • Enforces scopes/permissions │ +│ • Manages sessions │ +└────────────────────┬─────────────────────────────────────────┘ + │ + ┌───────────┴──────────┐ + │ │ +┌────────▼──────────┐ ┌────────▼──────────┐ +│ OAuth Providers │ │ Session Store │ +│ • GitHub │ │ • Memory │ +│ • Google │ │ • File │ +│ • Custom │ │ • Redis (future) │ +└───────────────────┘ └───────────────────┘ +``` + +--- + +## Component Architecture + +### 1. Auth Metadata Protocol + +**Location**: `nextmcp/protocol/auth_metadata.py` + +**Purpose**: Allows servers to **announce** their auth requirements so hosts can discover them. + +**Key Classes**: + +```python +class AuthMetadata: + """Top-level auth requirements.""" + requirement: AuthRequirement # REQUIRED, OPTIONAL, NONE + providers: list[AuthProviderMetadata] + required_scopes: list[str] + optional_scopes: list[str] + permissions: list[str] + supports_multi_user: bool + token_refresh_enabled: bool + +class AuthProviderMetadata: + """Single OAuth provider info.""" + name: str # "google", "github" + type: str # "oauth2" + flows: list[AuthFlowType] # [OAUTH2_PKCE] + authorization_url: str + token_url: str + scopes: list[str] + supports_refresh: bool +``` + +**Usage**: + +```python +# Server creates metadata +metadata = AuthMetadata(requirement=AuthRequirement.REQUIRED) +metadata.add_provider( + name="google", + type="oauth2", + flows=[AuthFlowType.OAUTH2_PKCE], + ... +) + +# Serialize for transmission +json_data = metadata.to_dict() + +# Host reads and understands auth requirements +metadata = AuthMetadata.from_dict(json_data) +``` + +**Why it exists**: Before this, hosts had no way to know what auth a server needed. Now they can: +- Show "Connect Google Account" UI +- Request appropriate OAuth scopes +- Handle auth failures gracefully + +--- + +### 2. Session Store + +**Location**: `nextmcp/session/session_store.py` + +**Purpose**: Persistent storage for OAuth tokens and user sessions + +**Interface**: + +```python +class SessionStore(ABC): + """Abstract interface for session storage.""" + + def save(self, session: SessionData) -> None + def load(self, user_id: str) -> SessionData | None + def delete(self, user_id: str) -> bool + def exists(self, user_id: str) -> bool + def list_users(self) -> list[str] + def clear_all(self) -> int + def update_tokens(...) -> None +``` + +**Implementations**: + +#### MemorySessionStore: +```python +class MemorySessionStore(SessionStore): + """In-memory, lost on restart.""" + _sessions: dict[str, SessionData] # user_id -> session + _lock: threading.RLock # Thread-safe +``` + +- **Pros**: Fast (O(1) lookup), simple +- **Cons**: Lost on restart, single-process only +- **Use case**: Development, testing + +#### FileSessionStore: +```python +class FileSessionStore(SessionStore): + """JSON files on disk.""" + directory: Path # .sessions/ + # Each user = one JSON file: session_user123.json +``` + +- **Pros**: Persists across restarts +- **Cons**: Single-server only, file I/O overhead +- **Use case**: Production (single server) + +#### Future: RedisSessionStore: +```python +class RedisSessionStore(SessionStore): + """Distributed, scalable.""" + # redis.set(f"session:{user_id}", json.dumps(session)) +``` + +- **Pros**: Distributed, scalable, TTL support +- **Cons**: Requires Redis +- **Use case**: Production (multi-server) + +**SessionData Model**: + +```python +@dataclass +class SessionData: + user_id: str + access_token: str | None + refresh_token: str | None + expires_at: float | None # Unix timestamp + scopes: list[str] + user_info: dict # Name, email, etc. + provider: str # "google", "github" + created_at: float + updated_at: float + metadata: dict # Custom app data + + def is_expired(self) -> bool + def needs_refresh(self, buffer_seconds=300) -> bool +``` + +--- + +### 3. Request Enforcement Middleware + +**Location**: `nextmcp/auth/request_middleware.py` + +**Purpose**: Intercept **every** MCP request and enforce auth automatically + +**How it Works**: + +``` +Request arrives + │ + ▼ +┌─────────────────────────────────────┐ +│ Is auth required? │ +│ (Check metadata.requirement) │ +└────┬────────────────────────────┬───┘ + │ │ + │ YES │ NO + ▼ │ +┌─────────────────────────────────┐ │ +│ Extract credentials from request│ │ +│ request.get("auth") │ │ +└────┬────────────────────────────┘ │ + │ │ + ▼ │ +┌─────────────────────────────────┐ │ +│ Check session store │ │ +│ Does user have existing session? │ │ +└────┬────────────────────────┬───┘ │ + │ │ │ + │ YES │ NO │ + ▼ ▼ │ +┌────────────────┐ ┌──────────────┐│ +│ Reuse session │ │ Authenticate ││ +│ Check expired │ │ with provider││ +└────┬───────────┘ └──────┬───────┘│ + │ │ │ + ▼ ▼ │ +┌──────────────────────────────────┐ │ +│ Token expired? │ │ +│ Auto-refresh if enabled │ │ +└────┬─────────────────────────────┘ │ + │ │ + ▼ │ +┌──────────────────────────────────┐ │ +│ Check scopes & permissions │ │ +│ (required_scopes, manifest) │ │ +└────┬─────────────────────────────┘ │ + │ │ + │ All checks passed │ + ▼ ▼ +┌──────────────────────────────────────┐ +│ Inject auth_context into request │ +│ request["_auth_context"] = context │ +└────┬─────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────┐ +│ Call next handler │ +│ Tool function executes │ +└──────────────────────────────────────┘ +``` + +**Code Structure**: + +```python +class AuthEnforcementMiddleware: + def __init__( + self, + provider: AuthProvider, + session_store: SessionStore | None, + metadata: AuthMetadata | None, + manifest: PermissionManifest | None, + auto_refresh_tokens: bool = True, + ): + ... + + async def __call__(self, request: dict, handler: Callable): + # 1. Check if auth required + if self.metadata.requirement == AuthRequirement.NONE: + return await handler(request) + + # 2. Extract credentials + credentials = request.get("auth", {}) + + # 3. Authenticate + auth_result = await self._authenticate(credentials, request) + + # 4. Check authorization + self._check_authorization(auth_result.context, request) + + # 5. Inject context + request["_auth_context"] = auth_result.context + + # 6. Call handler + return await handler(request) +``` + +**Key Methods**: + +```python +async def _authenticate(self, credentials, request): + """Validate credentials, check session, refresh if needed.""" + # 1. Check session store first (reuse) + # 2. If not found, call provider.authenticate() + # 3. Save new session + # 4. Auto-refresh if expiring + # 5. Return AuthResult + +def _check_authorization(self, auth_context, request): + """Check scopes and permissions.""" + # 1. Check required_scopes + # 2. Check manifest (if provided) + # 3. Raise AuthorizationError if denied +``` + +--- + +### 4. OAuth Providers + +**Location**: `nextmcp/auth/oauth.py`, `nextmcp/auth/oauth_providers.py` + +**Base Class**: + +```python +class OAuthProvider(AuthProvider, ABC): + """Base OAuth 2.0 provider with PKCE.""" + + def generate_authorization_url(self, state=None) -> dict: + """Create auth URL with PKCE.""" + pkce = PKCEChallenge.generate() + url = f"{self.config.authorization_url}?..." + return {"url": url, "state": state, "verifier": pkce.verifier} + + async def exchange_code_for_token(self, code, state, verifier): + """Exchange code for access token.""" + # POST to token_url with code + verifier + # Handle both JSON and form-encoded responses + return token_data + + async def refresh_access_token(self, refresh_token): + """Refresh expired token.""" + # POST to token_url with refresh_token + return new_token_data + + @abstractmethod + async def get_user_info(self, access_token): + """Provider-specific user info endpoint.""" + pass +``` + +**Built-in Providers**: + +```python +class GitHubOAuthProvider(OAuthProvider): + """GitHub OAuth with PKCE.""" + # authorization_url: github.com/login/oauth/authorize + # token_url: github.com/login/oauth/access_token + # Returns form-encoded tokens + +class GoogleOAuthProvider(OAuthProvider): + """Google OAuth with PKCE.""" + # authorization_url: accounts.google.com/o/oauth2/v2/auth + # token_url: oauth2.googleapis.com/token + # Returns JSON tokens + # Supports refresh tokens with access_type=offline +``` + +**PKCE Flow**: + +``` +1. Server generates PKCE challenge + verifier = random_43_chars() + challenge = SHA256(verifier) + +2. Server sends challenge in auth URL + redirect to: auth_url?code_challenge=... + +3. User authorizes, provider sends code + +4. Server exchanges code + verifier for token + POST token_url with code + code_verifier + +5. Provider validates: SHA256(verifier) == challenge + Returns access_token + refresh_token +``` + +**Why PKCE**: Secure for public clients (no client secret exposed) + +--- + +## Request Flow + +### Complete Request Lifecycle: + +``` +1. Client sends request + ┌────────────────────────────────────┐ + │ { │ + │ "method": "tools/call", │ + │ "params": {"name": "get_data"}, │ + │ "auth": { │ + │ "access_token": "ya29.a0..." │ + │ } │ + │ } │ + └────────────────────────────────────┘ + │ + ▼ +2. Middleware intercepts + ┌────────────────────────────────────┐ + │ AuthEnforcementMiddleware.__call__ │ + └────────────────────────────────────┘ + │ + ▼ +3. Check auth requirement + if metadata.requirement == NONE: + skip auth + │ + ▼ +4. Extract credentials + access_token = request["auth"]["access_token"] + │ + ▼ +5. Check session store + session = session_store.load_by_token(access_token) + if session: + if session.is_expired(): + reject + if session.needs_refresh(): + auto_refresh + use session + │ + ▼ +6. Or authenticate with provider + result = await provider.authenticate({ + "access_token": access_token + }) + │ + ▼ +7. Save new session + session_store.save(SessionData(...)) + │ + ▼ +8. Check authorization + • Check required_scopes + • Check manifest permissions + │ + ▼ +9. Inject auth context + request["_auth_context"] = AuthContext(...) + │ + ▼ +10. Call tool handler + result = await tool_function() + │ + ▼ +11. Return result to client + ┌────────────────────────────────────┐ + │ { │ + │ "result": {...} │ + │ } │ + └────────────────────────────────────┘ +``` + +--- + +## Data Flow + +### OAuth Token Acquisition: + +``` +┌──────────┐ ┌──────────────┐ +│ Client │ │ OAuth Server │ +│ (Host) │ │ (Google) │ +└────┬─────┘ └──────┬───────┘ + │ │ + │ 1. GET /auth/metadata │ + ├──────────────────────────────► │ + │ Returns auth requirements │ + │ {providers: [google], ...} │ + │◄────────────────────────────────┤ + │ │ + │ 2. Generate auth URL │ + │ provider.generate_authorization_url() + │ Returns: {url, state, verifier} │ + │ │ + │ 3. Open browser to auth URL │ + ├─────────────────────────────────► + │ │ + │ 4. User authorizes │ + │ ┌────────┤ + │ │Authorize│ + │ └────────┤ + │ │ + │ 5. Redirect to callback with code + │◄─────────────────────────────────┤ + │ http://localhost:8080/callback? │ + │ code=abc123&state=xyz │ + │ │ + │ 6. Exchange code for token │ + │ POST /token │ + │ code=abc123 │ + │ code_verifier=... │ + ├─────────────────────────────────► + │ │ + │ 7. Returns tokens │ + │ {access_token, refresh_token} │ + │◄─────────────────────────────────┤ + │ │ + │ 8. Save session │ + │ session_store.save(...) │ + │ │ + │ 9. Make authenticated requests │ + │ {auth: {access_token: ...}} │ + └────────────────────────────────── +``` + +--- + +## Design Decisions + +### Why Middleware Instead of Decorators? + +**Decorators** (old way): +```python +@requires_auth_async(provider) +@requires_scope_async("read:data") +async def tool1(): + pass + +@requires_auth_async(provider) +@requires_scope_async("read:data") +async def tool2(): + pass +``` + +**Middleware** (new way): +```python +server.use(create_auth_middleware( + provider=provider, + required_scopes=["read:data"] +)) + +# All tools automatically protected +``` + +**Advantages**: +1. **DRY**: Don't repeat decorators on every tool +2. **Centralized**: One place to configure auth +3. **Automatic**: Impossible to forget to add auth +4. **Flexible**: Can still use decorators for tool-specific requirements + +### Why Session Store? + +**Without session store**: +- Must re-authenticate every request +- Can't refresh tokens +- Can't support multiple users +- No session persistence + +**With session store**: +- ✅ Authenticate once, reuse session +- ✅ Automatic token refresh +- ✅ Multi-user support +- ✅ Persists across restarts + +### Why Auth Metadata Protocol? + +**Before**: Hosts had to guess or hardcode server auth requirements + +**After**: Hosts can discover auth requirements dynamically + +**Benefits**: +- Standardized auth discovery +- Better UX (show "Connect Google" UI) +- Future-proof (new auth methods just work) + +--- + +## Security Considerations + +### Token Storage + +**Tokens in memory**: MemorySessionStore +- ✅ Fast +- ⚠️ Lost on crash +- ⚠️ Vulnerable to memory dumps + +**Tokens on disk**: FileSessionStore +- ✅ Persists +- ⚠️ Vulnerable to file system access +- **Mitigation**: Proper file permissions (chmod 600) + +**Tokens in Redis**: RedisSessionStore (future) +- ✅ Distributed +- ✅ TTL support +- **Security**: Encrypt Redis connection, use ACLs + +### PKCE + +**Why**: Prevents authorization code interception attacks + +**How**: Verifier proves client initiated the auth flow + +**Without PKCE**: Attacker could intercept code and use it + +**With PKCE**: Attacker can't use code without verifier + +### Token Refresh + +**Automatic refresh**: +- ✅ Good UX (no expiration errors) +- ⚠️ Longer-lived access + +**Manual refresh**: +- ⚠️ Poor UX (users see errors) +- ✅ Shorter-lived access + +**Recommendation**: Use automatic refresh with short-lived access tokens (1 hour) + +### Scope Validation + +**Always validate scopes**: +```python +middleware = create_auth_middleware( + required_scopes=["profile", "email"], + # Don't trust client claims! +) +``` + +**Why**: Client could send fake scopes + +**How**: Middleware validates against actual OAuth token scopes + +--- + +## Performance Characteristics + +### Session Store Performance: + +| Operation | Memory | File | Redis | +|-----------|--------|------|-------| +| Save | O(1) | O(1) | O(1) | +| Load | O(1) | O(1) | O(1) | +| List | O(n) | O(n) | O(n) | +| Delete | O(1) | O(1) | O(1) | + +### Middleware Overhead: + +**Per request**: +1. Session lookup: O(1) - ~0.1ms +2. Token validation: O(1) - ~0.5ms (if cached) +3. Scope check: O(m) where m = number of scopes - ~0.01ms +4. **Total overhead**: ~1ms per request + +**With auto-refresh**: +- Check if token expiring: O(1) - ~0.01ms +- Refresh if needed: ~500ms (network call, rare) + +### Scalability: + +**Single server**: +- MemorySessionStore: 10,000+ concurrent users +- FileSessionStore: 100,000+ users + +**Distributed**: +- RedisSessionStore: Millions of users + +--- + +## Extension Points + +### Custom Session Store: + +```python +class DatabaseSessionStore(SessionStore): + """Store sessions in PostgreSQL.""" + + def __init__(self, db_url): + self.engine = create_engine(db_url) + + def save(self, session): + with self.engine.connect() as conn: + conn.execute( + "INSERT INTO sessions (...) VALUES (...)" + ) +``` + +### Custom OAuth Provider: + +```python +class CustomOAuthProvider(OAuthProvider): + """Your company's OAuth.""" + + async def get_user_info(self, access_token): + # Call your user info endpoint + return user_data + + def extract_user_id(self, user_info): + return user_info["id"] +``` + +### Custom Middleware: + +```python +class AuditMiddleware: + """Log all auth events.""" + + async def __call__(self, request, handler): + auth_context = request.get("_auth_context") + if auth_context: + log_auth_event(auth_context.user_id, request) + return await handler(request) +``` + +--- + +## Summary + +NextMCP's auth architecture provides: + +1. **Discovery** - Servers announce requirements (Auth Metadata) +2. **Enforcement** - Every request is validated (Request Middleware) +3. **Persistence** - Sessions survive restarts (Session Store) +4. **Flexibility** - Pluggable providers and stores +5. **Security** - PKCE, scope validation, token refresh +6. **Performance** - <1ms overhead per request + +The three-layer design (metadata + middleware + storage) creates a complete, production-ready auth system for MCP servers. diff --git a/docs/ENV_SETUP.md b/docs/ENV_SETUP.md new file mode 100644 index 0000000..b24e26e --- /dev/null +++ b/docs/ENV_SETUP.md @@ -0,0 +1,228 @@ +# Environment Variables Setup for OAuth Testing + +Quick guide for setting up environment variables for OAuth integration tests. + +## Quick Start + +```bash +# 1. Create .env file from template +cp .env.example .env + +# Or use the setup script +bash scripts/setup_env.sh + +# 2. Get OAuth tokens using helper script +python examples/auth/oauth_token_helper.py + +# 3. Edit .env file and paste your credentials +nano .env # or vim, code, etc. + +# 4. Load environment variables +export $(cat .env | grep -v '^#' | xargs) + +# 5. Verify +echo $GITHUB_CLIENT_ID +echo $GITHUB_ACCESS_TOKEN + +# 6. Run tests +pytest tests/test_oauth_integration.py -v -m integration +``` + +## .env File Format + +The `.env.example` file provides a template with all required variables: + +```bash +# GitHub OAuth +GITHUB_CLIENT_ID=your_github_client_id_here +GITHUB_CLIENT_SECRET=your_github_client_secret_here +GITHUB_ACCESS_TOKEN=gho_your_github_access_token_here + +# Google OAuth +GOOGLE_CLIENT_ID=your_google_client_id.apps.googleusercontent.com +GOOGLE_CLIENT_SECRET=your_google_client_secret_here +GOOGLE_ACCESS_TOKEN=ya29.your_google_access_token_here +GOOGLE_REFRESH_TOKEN=1//your_google_refresh_token_here +``` + +## Loading Environment Variables + +### Option 1: Shell Export (Temporary) + +Loads variables for current shell session only: + +```bash +# Load all variables (filters comments) +export $(cat .env | grep -v '^#' | xargs) + +# Simpler version (includes comments in variable names, may cause issues) +export $(cat .env | xargs) +``` + +### Option 2: Shell Configuration (Persistent) + +Add to `~/.bashrc` or `~/.zshrc` to load automatically: + +```bash +# Add to ~/.bashrc or ~/.zshrc +if [ -f ~/path/to/nextmcp/.env ]; then + export $(cat ~/path/to/nextmcp/.env | grep -v '^#' | xargs) +fi +``` + +### Option 3: Direnv (Automatic) + +Install [direnv](https://direnv.net/) for automatic loading: + +```bash +# Install direnv +brew install direnv # macOS +# or +sudo apt install direnv # Linux + +# Add to shell config +echo 'eval "$(direnv hook bash)"' >> ~/.bashrc # for bash +# or +echo 'eval "$(direnv hook zsh)"' >> ~/.zshrc # for zsh + +# Allow .env in project directory +cd /path/to/nextmcp +direnv allow . + +# Now .env loads automatically when you cd into the directory +``` + +### Option 4: Python dotenv + +Load in Python code: + +```python +from dotenv import load_dotenv + +load_dotenv() # Loads .env from current directory + +# Now os.getenv() will find the variables +import os +client_id = os.getenv("GITHUB_CLIENT_ID") +``` + +## Security Best Practices + +### ✅ DO: +- Keep `.env` files local only (never commit to git) +- Use different `.env` files for different environments (`.env.test`, `.env.production`) +- Rotate tokens regularly +- Use minimal scopes needed for testing +- Add `.env*` to `.gitignore` (already done) + +### ❌ DON'T: +- Commit `.env` files to git (already ignored) +- Share `.env` files via email or chat +- Use production credentials for testing +- Grant unnecessary OAuth scopes +- Keep tokens that aren't being used + +## Troubleshooting + +### Variables Not Loading + +```bash +# Check if .env exists +ls -la .env + +# Check .env contents (be careful - contains secrets!) +cat .env + +# Verify export command worked +echo $GITHUB_CLIENT_ID + +# If empty, manually export one variable to test +export GITHUB_CLIENT_ID="test_value" +echo $GITHUB_CLIENT_ID +``` + +### Tests Still Skipping + +```bash +# Run setup instruction test to see status +pytest tests/test_oauth_integration.py::test_show_setup_instructions -v -s + +# This will show which variables are set/missing +``` + +### Invalid Tokens + +Access tokens expire: +- **GitHub**: Personal access tokens don't expire (until revoked) +- **Google**: Access tokens expire after 1 hour + +Re-run the helper script to get fresh tokens: + +```bash +python examples/auth/oauth_token_helper.py --provider google +``` + +### Permission Denied + +If the export command fails: + +```bash +# Check file permissions +ls -l .env + +# Should be readable by you +# If not: +chmod 600 .env +``` + +## Environment Variables Reference + +| Variable | Required For | How to Get | +|----------|-------------|------------| +| `GITHUB_CLIENT_ID` | GitHub URL generation | GitHub Settings → Developer Settings → OAuth Apps | +| `GITHUB_CLIENT_SECRET` | GitHub URL generation | Same as above | +| `GITHUB_ACCESS_TOKEN` | GitHub API tests | Run `oauth_token_helper.py --provider github` | +| `GOOGLE_CLIENT_ID` | Google URL generation | Google Cloud Console → Credentials | +| `GOOGLE_CLIENT_SECRET` | Google URL generation | Same as above | +| `GOOGLE_ACCESS_TOKEN` | Google API tests | Run `oauth_token_helper.py --provider google` | +| `GOOGLE_REFRESH_TOKEN` | Token refresh tests | Same as above (issued on first auth) | + +## Alternative: GitHub Actions Secrets + +For CI/CD, use GitHub Actions secrets instead of .env files: + +```yaml +# .github/workflows/integration-tests.yml +name: Integration Tests + +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + pip install -e ".[dev,oauth]" + + - name: Run integration tests + env: + GITHUB_CLIENT_ID: ${{ secrets.GITHUB_CLIENT_ID }} + GITHUB_CLIENT_SECRET: ${{ secrets.GITHUB_CLIENT_SECRET }} + GITHUB_ACCESS_TOKEN: ${{ secrets.GITHUB_ACCESS_TOKEN }} + run: | + pytest tests/test_oauth_integration.py -v -m integration +``` + +Then add secrets in: Repository Settings → Secrets and variables → Actions + +## See Also + +- [Complete OAuth Setup Guide](OAUTH_TESTING_SETUP.md) - Detailed instructions +- [OAuth Examples](../examples/auth/) - Example implementations +- [Integration Tests](../tests/test_oauth_integration.py) - Test source code diff --git a/docs/HOST_INTEGRATION.md b/docs/HOST_INTEGRATION.md new file mode 100644 index 0000000..ad44710 --- /dev/null +++ b/docs/HOST_INTEGRATION.md @@ -0,0 +1,733 @@ +# Host Integration Guide for NextMCP Auth + +This guide is for **host developers** (Claude Desktop, Cursor, Windsurf, Zed, etc.) who want to support NextMCP's authentication system in their applications. + +--- + +## Table of Contents + +1. [Overview](#overview) +2. [Quick Start](#quick-start) +3. [Discovery: Reading Auth Metadata](#discovery-reading-auth-metadata) +4. [OAuth Flow Implementation](#oauth-flow-implementation) +5. [Sending Authenticated Requests](#sending-authenticated-requests) +6. [Token Management](#token-management) +7. [Error Handling](#error-handling) +8. [UI/UX Recommendations](#uiux-recommendations) +9. [Implementation Checklist](#implementation-checklist) + +--- + +## Overview + +### What is NextMCP Auth? + +NextMCP provides a standard way for MCP servers to require authentication. Servers can: +- Announce their auth requirements (OAuth providers, scopes, permissions) +- Enforce authentication automatically on all requests +- Manage user sessions with token refresh + +### What Hosts Need to Do: + +1. **Discovery**: Read server's auth metadata +2. **OAuth Flow**: Implement OAuth 2.0 with PKCE +3. **Token Storage**: Store and manage user tokens +4. **Request Injection**: Include auth credentials in requests +5. **Error Handling**: Handle auth failures gracefully + +--- + +## Quick Start + +### Minimum Viable Integration: + +```typescript +// 1. Check if server requires auth +const metadata = await server.call("get_auth_metadata"); + +if (metadata.auth.requirement === "required") { + // 2. Get OAuth provider info + const provider = metadata.auth.providers[0]; // e.g., Google + + // 3. Run OAuth flow (see detailed section) + const tokens = await runOAuthFlow(provider); + + // 4. Store tokens + await storage.setTokens(serverId, tokens); +} + +// 5. Make authenticated requests +const result = await server.call("some_tool", { + auth: { + access_token: tokens.access_token + } +}); +``` + +--- + +## Discovery: Reading Auth Metadata + +### Step 1: Check if Server Exposes Metadata + +Not all servers will have this tool, so check: + +```typescript +const tools = await server.listTools(); +const hasAuthMetadata = tools.some(t => t.name === "get_auth_metadata"); + +if (hasAuthMetadata) { + const metadata = await server.call("get_auth_metadata"); + // Process metadata... +} +``` + +### Step 2: Parse Auth Metadata + +The metadata follows this schema: + +```typescript +interface AuthMetadata { + requirement: "required" | "optional" | "none"; + providers: AuthProvider[]; + required_scopes: string[]; + optional_scopes: string[]; + permissions: string[]; + roles: string[]; + supports_multi_user: boolean; + session_management: "server-side" | "client-side" | "stateless"; + token_refresh_enabled: boolean; +} + +interface AuthProvider { + name: string; // "google", "github" + type: string; // "oauth2" + flows: string[]; // ["oauth2-pkce"] + authorization_url: string; + token_url: string; + scopes: string[]; + supports_refresh: boolean; + supports_pkce: boolean; +} +``` + +### Example Metadata Response: + +```json +{ + "requirement": "required", + "providers": [ + { + "name": "google", + "type": "oauth2", + "flows": ["oauth2-pkce"], + "authorization_url": "https://accounts.google.com/o/oauth2/v2/auth", + "token_url": "https://oauth2.googleapis.com/token", + "scopes": ["openid", "email", "profile"], + "supports_refresh": true, + "supports_pkce": true + } + ], + "required_scopes": ["openid", "email"], + "optional_scopes": ["profile"], + "supports_multi_user": true, + "token_refresh_enabled": true +} +``` + +### Step 3: Decision Logic + +```typescript +function handleAuthMetadata(metadata: AuthMetadata) { + switch (metadata.requirement) { + case "none": + // No auth needed, proceed normally + return; + + case "optional": + // Show "Sign in to unlock features" UI + showOptionalAuthPrompt(metadata); + break; + + case "required": + // Block until user authenticates + showRequiredAuthPrompt(metadata); + break; + } +} +``` + +--- + +## OAuth Flow Implementation + +### Overview: OAuth 2.0 with PKCE + +``` +1. Host generates PKCE verifier + challenge +2. Host opens browser to authorization_url +3. User authorizes +4. Provider redirects back with code +5. Host exchanges code + verifier for tokens +6. Host stores tokens +``` + +### Step 1: Generate PKCE Challenge + +```typescript +import crypto from 'crypto'; + +function generatePKCE() { + // Generate random verifier (43-128 chars) + const verifier = crypto.randomBytes(32) + .toString('base64url'); + + // Generate SHA256 challenge + const challenge = crypto.createHash('sha256') + .update(verifier) + .digest('base64url'); + + return { verifier, challenge }; +} +``` + +### Step 2: Build Authorization URL + +```typescript +function buildAuthUrl(provider: AuthProvider, pkce: PKCE): string { + const state = crypto.randomBytes(16).toString('hex'); + + const params = new URLSearchParams({ + client_id: provider.client_id, // From your OAuth app + redirect_uri: "http://localhost:8080/oauth/callback", + response_type: "code", + state: state, + code_challenge: pkce.challenge, + code_challenge_method: "S256", + scope: provider.scopes.join(" "), + }); + + // For Google, add access_type=offline for refresh tokens + if (provider.name === "google" && provider.supports_refresh) { + params.set("access_type", "offline"); + params.set("prompt", "consent"); + } + + return `${provider.authorization_url}?${params}`; +} +``` + +### Step 3: Handle OAuth Callback + +You need a local HTTP server to receive the redirect: + +```typescript +import http from 'http'; + +async function waitForCallback(): Promise<{ code: string; state: string }> { + return new Promise((resolve, reject) => { + const server = http.createServer((req, res) => { + const url = new URL(req.url!, 'http://localhost:8080'); + + if (url.pathname === '/oauth/callback') { + const code = url.searchParams.get('code'); + const error = url.searchParams.get('error'); + const state = url.searchParams.get('state'); + + if (error) { + res.writeHead(400); + res.end(`Error: ${error}`); + reject(new Error(error)); + } else if (code && state) { + res.writeHead(200); + res.end('✅ Authorization successful! You can close this window.'); + resolve({ code, state }); + } + + server.close(); + } + }); + + server.listen(8080); + }); +} +``` + +### Step 4: Exchange Code for Tokens + +```typescript +async function exchangeCodeForTokens( + provider: AuthProvider, + code: string, + verifier: string, +): Promise { + const response = await fetch(provider.token_url, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + body: new URLSearchParams({ + grant_type: 'authorization_code', + code: code, + redirect_uri: 'http://localhost:8080/oauth/callback', + client_id: provider.client_id, + code_verifier: verifier, + // client_secret only if you have one (confidential clients) + }), + }); + + if (!response.ok) { + throw new Error(`Token exchange failed: ${response.statusText}`); + } + + // GitHub returns form-encoded, Google returns JSON + const contentType = response.headers.get('content-type'); + + if (contentType?.includes('application/json')) { + return await response.json(); + } else { + // Parse form-encoded response + const text = await response.text(); + const params = new URLSearchParams(text); + return { + access_token: params.get('access_token'), + token_type: params.get('token_type'), + expires_in: parseInt(params.get('expires_in') || '3600'), + refresh_token: params.get('refresh_token'), + scope: params.get('scope'), + }; + } +} +``` + +### Complete Flow: + +```typescript +async function runOAuthFlow(provider: AuthProvider): Promise { + // 1. Generate PKCE + const pkce = generatePKCE(); + const state = crypto.randomBytes(16).toString('hex'); + + // 2. Build auth URL + const authUrl = buildAuthUrl(provider, pkce, state); + + // 3. Open browser + await openBrowser(authUrl); + + // 4. Start local server and wait for callback + const { code, state: returnedState } = await waitForCallback(); + + // 5. Validate state (CSRF protection) + if (returnedState !== state) { + throw new Error('State mismatch - possible CSRF attack'); + } + + // 6. Exchange code for tokens + const tokens = await exchangeCodeForTokens(provider, code, pkce.verifier); + + return tokens; +} +``` + +--- + +## Sending Authenticated Requests + +### Include Auth in Every Request + +Once you have tokens, include them in the request: + +```typescript +const request = { + method: "tools/call", + params: { + name: "get_user_data", + arguments: { + user_id: "123" + } + }, + auth: { + access_token: tokens.access_token, + // Optional: include other token info + refresh_token: tokens.refresh_token, + scopes: tokens.scope?.split(' '), + } +}; + +const response = await server.send(request); +``` + +### Where to Inject Auth: + +NextMCP middleware looks for `request["auth"]`, so: + +```typescript +// Correct ✅ +{ + "method": "tools/call", + "params": {...}, + "auth": { + "access_token": "ya29.a0..." + } +} + +// Incorrect ❌ (won't work) +{ + "method": "tools/call", + "params": {...}, + "headers": { + "Authorization": "Bearer ya29.a0..." + } +} +``` + +--- + +## Token Management + +### Store Tokens Securely + +```typescript +interface StoredTokens { + access_token: string; + refresh_token?: string; + expires_at: number; // Unix timestamp + scope: string; + provider: string; + user_info?: { + email: string; + name: string; + }; +} + +class TokenStore { + async saveTokens(serverId: string, tokens: StoredTokens) { + // Use OS keychain, encrypted storage, etc. + await keychain.set(`mcp:${serverId}`, JSON.stringify(tokens)); + } + + async loadTokens(serverId: string): Promise { + const data = await keychain.get(`mcp:${serverId}`); + return data ? JSON.parse(data) : null; + } + + async deleteTokens(serverId: string) { + await keychain.delete(`mcp:${serverId}`); + } +} +``` + +### Check Token Expiration + +```typescript +function isTokenExpired(tokens: StoredTokens): boolean { + if (!tokens.expires_at) return false; + return Date.now() / 1000 >= tokens.expires_at; +} + +function needsRefresh(tokens: StoredTokens, bufferSeconds = 300): boolean { + if (!tokens.expires_at) return false; + return Date.now() / 1000 >= tokens.expires_at - bufferSeconds; +} +``` + +### Refresh Tokens + +```typescript +async function refreshTokens( + provider: AuthProvider, + refreshToken: string, +): Promise { + const response = await fetch(provider.token_url, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + body: new URLSearchParams({ + grant_type: 'refresh_token', + refresh_token: refreshToken, + client_id: provider.client_id, + }), + }); + + if (!response.ok) { + // Refresh failed - need to re-authenticate + throw new Error('Token refresh failed'); + } + + return await response.json(); +} +``` + +### Auto-Refresh Strategy + +```typescript +async function getValidToken(serverId: string): Promise { + const stored = await tokenStore.loadTokens(serverId); + + if (!stored) { + // No tokens - need to authenticate + throw new Error('Not authenticated'); + } + + if (isTokenExpired(stored)) { + // Token expired and no refresh token + if (!stored.refresh_token) { + throw new Error('Token expired - please re-authenticate'); + } + + // Try to refresh + try { + const newTokens = await refreshTokens(provider, stored.refresh_token); + await tokenStore.saveTokens(serverId, { + ...newTokens, + expires_at: Date.now() / 1000 + newTokens.expires_in, + }); + return newTokens.access_token; + } catch (e) { + // Refresh failed - need to re-authenticate + throw new Error('Token refresh failed - please re-authenticate'); + } + } + + if (needsRefresh(stored)) { + // Preemptively refresh (don't wait for request to fail) + refreshTokens(provider, stored.refresh_token!) + .then(newTokens => tokenStore.saveTokens(serverId, { + ...newTokens, + expires_at: Date.now() / 1000 + newTokens.expires_in, + })) + .catch(() => { + // Refresh failed, but current token still valid + // Will be handled on next check + }); + } + + return stored.access_token; +} +``` + +--- + +## Error Handling + +### Auth Error Types + +NextMCP servers return structured errors: + +```typescript +interface AuthError { + error: "authentication_required" | "authorization_denied" | "token_expired"; + message: string; + required_scopes?: string[]; + providers?: AuthProvider[]; +} +``` + +### Handle Authentication Errors + +```typescript +async function handleRequest(request: any) { + try { + const response = await server.send(request); + return response; + } catch (error) { + if (error.error === "authentication_required") { + // Show "Sign in required" UI + const tokens = await promptUserToSignIn(error.providers); + // Retry request + return await server.send({ + ...request, + auth: { access_token: tokens.access_token } + }); + } + + if (error.error === "authorization_denied") { + // User lacks required scopes + showError(`Missing permissions: ${error.required_scopes.join(', ')}`); + // Optionally: re-run OAuth flow with additional scopes + } + + if (error.error === "token_expired") { + // Token expired - try refresh + const tokens = await refreshOrReauth(serverId); + return await server.send({ + ...request, + auth: { access_token: tokens.access_token } + }); + } + + throw error; + } +} +``` + +--- + +## UI/UX Recommendations + +### 1. Server Connection Flow + +``` +User adds server + │ + ▼ +Check auth metadata + │ + ├─► No auth required + │ → Connect immediately + │ + └─► Auth required + │ + ▼ + Show auth prompt: + "This server requires authentication" + [Connect with Google] [Connect with GitHub] + │ + ▼ + Run OAuth flow + │ + ▼ + Store tokens + │ + ▼ + Server connected ✓ +``` + +### 2. Auth Prompt Design + +**For Required Auth**: +``` +┌─────────────────────────────────────────┐ +│ 📡 Connect to "My MCP Server" │ +│ │ +│ This server requires authentication │ +│ to protect your data. │ +│ │ +│ Required permissions: │ +│ • Read your profile │ +│ • Access your files │ +│ │ +│ [🔐 Connect with Google] │ +│ [🔐 Connect with GitHub] │ +│ │ +│ [ Cancel ] │ +└─────────────────────────────────────────┘ +``` + +**For Optional Auth**: +``` +┌─────────────────────────────────────────┐ +│ 📡 "My MCP Server" │ +│ │ +│ ✓ Connected (Limited Features) │ +│ │ +│ Sign in to unlock: │ +│ • Personalized responses │ +│ • Save your preferences │ +│ • Access premium features │ +│ │ +│ [Sign in] [Maybe later] │ +└─────────────────────────────────────────┘ +``` + +### 3. Token Status Indicator + +``` +Server: My MCP Server +Status: ✓ Authenticated as user@example.com +Token expires: in 45 minutes +[Refresh] [Sign out] +``` + +### 4. Error Messages + +**Good**: +``` +❌ Authentication expired + +Your session has expired. Please sign in again. + +[Sign in with Google] +``` + +**Bad**: +``` +Error: 401 Unauthorized +``` + +--- + +## Implementation Checklist + +### Core Features: +- [ ] Read auth metadata from servers +- [ ] Implement OAuth 2.0 with PKCE +- [ ] Local callback server (http://localhost:8080) +- [ ] Token storage (encrypted/keychain) +- [ ] Inject auth in requests +- [ ] Handle auth errors +- [ ] Token refresh logic + +### UX Features: +- [ ] "Connect" UI for OAuth providers +- [ ] Auth status indicator +- [ ] Token expiration warnings +- [ ] Re-authentication flow +- [ ] Sign out functionality +- [ ] Multi-account support (optional) + +### Security: +- [ ] State parameter validation (CSRF) +- [ ] Secure token storage +- [ ] HTTPS for callback URL (production) +- [ ] Scope validation + +### Polish: +- [ ] Provider logos (Google, GitHub) +- [ ] Loading states during OAuth +- [ ] Error recovery +- [ ] Offline mode handling + +--- + +## Reference Implementation + +For a complete reference implementation, see: + +- **Python Client**: `examples/auth/oauth_token_helper.py` +- **Tests**: `tests/test_oauth_integration.py` + +--- + +## Future Enhancements + +These are not required now but may be added in future: + +1. **Multiple Accounts**: Support multiple users per server +2. **Account Switching**: Switch between Google/GitHub accounts +3. **Permission Negotiation**: Dynamic scope requests +4. **SSO Integration**: Enterprise SSO support + +--- + +## Getting Help + +- NextMCP Issues: https://github.com/anthropics/nextmcp/issues +- MCP Specification: https://modelcontextprotocol.io +- OAuth 2.0 Spec: https://oauth.net/2/ + +--- + +## Summary + +To support NextMCP auth in your host application: + +1. **Read** auth metadata to discover requirements +2. **Implement** OAuth 2.0 with PKCE for authorization +3. **Store** tokens securely +4. **Inject** auth credentials in every request +5. **Refresh** tokens automatically +6. **Handle** errors gracefully with good UX + +The auth system is designed to be straightforward to integrate while providing enterprise-grade security. diff --git a/docs/MIGRATION_GUIDE.md b/docs/MIGRATION_GUIDE.md new file mode 100644 index 0000000..08322ea --- /dev/null +++ b/docs/MIGRATION_GUIDE.md @@ -0,0 +1,594 @@ +# Migration Guide: Adding Auth to Your MCP Server + +This guide shows you how to add NextMCP authentication to your existing MCP server or migrate from FastMCP to NextMCP with auth. + +--- + +## Table of Contents + +1. [Quick Start](#quick-start) +2. [Adding OAuth to Existing Servers](#adding-oauth-to-existing-servers) +3. [Migration from FastMCP](#migration-from-fastmcp) +4. [Adding Session Management](#adding-session-management) +5. [Migrating from Decorators to Middleware](#migrating-from-decorators-to-middleware) +6. [Common Patterns](#common-patterns) +7. [Troubleshooting](#troubleshooting) + +--- + +## Quick Start + +### Before (No Auth): + +```python +from fastmcp import FastMCP + +mcp = FastMCP("My Server") + +@mcp.tool() +def get_user_data(user_id: str) -> dict: + """Get user data - anyone can call this!""" + return {"user_id": user_id, "data": "sensitive info"} +``` + +### After (With OAuth): + +```python +from fastmcp import FastMCP +from nextmcp.auth import GoogleOAuthProvider, create_auth_middleware +from nextmcp.session import MemorySessionStore +from nextmcp.protocol import AuthRequirement + +mcp = FastMCP("My Server") + +# Set up OAuth +google = GoogleOAuthProvider( + client_id="your-client-id", + client_secret="your-client-secret", +) + +# Enable auth enforcement +auth_middleware = create_auth_middleware( + provider=google, + requirement=AuthRequirement.REQUIRED, + session_store=MemorySessionStore(), + required_scopes=["profile", "email"], +) + +mcp.use(auth_middleware) + +@mcp.tool() +def get_user_data(user_id: str) -> dict: + """Get user data - now requires OAuth authentication!""" + # Request automatically has _auth_context injected + return {"user_id": user_id, "data": "sensitive info"} +``` + +**That's it!** Your server now requires OAuth authentication for all requests. + +--- + +## Adding OAuth to Existing Servers + +### Step 1: Install Dependencies + +If using OAuth, ensure you have the oauth extras: + +```bash +pip install "nextmcp[oauth]" +``` + +### Step 2: Choose Your OAuth Provider + +NextMCP includes two built-in providers: + +#### GitHub OAuth: + +```python +from nextmcp.auth import GitHubOAuthProvider + +github = GitHubOAuthProvider( + client_id="your_github_client_id", + client_secret="your_github_client_secret", + redirect_uri="http://localhost:8080/oauth/callback", # Optional + scope=["read:user", "repo"], # Optional +) +``` + +**Get credentials**: https://github.com/settings/developers + +#### Google OAuth: + +```python +from nextmcp.auth import GoogleOAuthProvider + +google = GoogleOAuthProvider( + client_id="your_google_client_id", + client_secret="your_google_client_secret", + redirect_uri="http://localhost:8080/oauth/callback", # Optional + scope=["openid", "email", "profile"], # Optional +) +``` + +**Get credentials**: https://console.cloud.google.com + +### Step 3: Add Session Store + +Choose a session store based on your needs: + +#### Development (In-Memory): + +```python +from nextmcp.session import MemorySessionStore + +session_store = MemorySessionStore() +``` + +**Pros**: Fast, simple +**Cons**: Lost on restart, not distributed + +#### Production (File-Based): + +```python +from nextmcp.session import FileSessionStore + +session_store = FileSessionStore(".sessions") +``` + +**Pros**: Persists across restarts +**Cons**: Single-server only + +#### Future (Redis - Coming Soon): + +```python +# from nextmcp.session import RedisSessionStore +# session_store = RedisSessionStore("redis://localhost:6379") +``` + +**Pros**: Distributed, scalable +**Cons**: Requires Redis + +### Step 4: Apply Middleware + +```python +from nextmcp.auth import create_auth_middleware +from nextmcp.protocol import AuthRequirement + +middleware = create_auth_middleware( + provider=google, # or github + requirement=AuthRequirement.REQUIRED, + session_store=session_store, + required_scopes=["profile", "email"], +) + +# Apply to your server +mcp.use(middleware) +``` + +### Step 5: Expose Auth Metadata (Optional but Recommended) + +Let clients discover your auth requirements: + +```python +from nextmcp.protocol import AuthMetadata, AuthFlowType + +# Build metadata +metadata = AuthMetadata( + requirement=AuthRequirement.REQUIRED, + supports_multi_user=True, + token_refresh_enabled=True, +) + +metadata.add_provider( + name="google", + type="oauth2", + flows=[AuthFlowType.OAUTH2_PKCE], + authorization_url="https://accounts.google.com/o/oauth2/v2/auth", + token_url="https://oauth2.googleapis.com/token", + scopes=["openid", "email", "profile"], + supports_refresh=True, +) + +# Expose via an endpoint +@mcp.tool() +def get_auth_metadata() -> dict: + """Get server authentication requirements.""" + return metadata.to_dict() +``` + +--- + +## Migration from FastMCP + +If you're using FastMCP and want to add auth: + +### Pattern 1: Add Auth to All Tools + +```python +# Before +from fastmcp import FastMCP + +mcp = FastMCP("My Server") + +@mcp.tool() +def tool1(): + pass + +@mcp.tool() +def tool2(): + pass + +# After - Add middleware +from nextmcp.auth import create_auth_middleware, GoogleOAuthProvider +from nextmcp.session import MemorySessionStore + +google = GoogleOAuthProvider(client_id="...", client_secret="...") +mcp.use(create_auth_middleware( + provider=google, + session_store=MemorySessionStore(), +)) + +# All tools now require auth automatically! +``` + +### Pattern 2: Mix Public and Protected Tools + +```python +from nextmcp.protocol import AuthRequirement + +# Create middleware with OPTIONAL auth +middleware = create_auth_middleware( + provider=google, + requirement=AuthRequirement.OPTIONAL, + session_store=MemorySessionStore(), +) + +mcp.use(middleware) + +@mcp.tool() +def public_tool(): + """Anyone can call this.""" + return "public data" + +@mcp.tool() +def protected_tool(): + """Requires auth but middleware handles it.""" + # Check if authenticated + # (would need to access _auth_context from request) + return "protected data" +``` + +For fine-grained control, use decorators on specific tools: + +```python +from nextmcp.auth import requires_auth_async + +@mcp.tool() +async def public_tool(): + """No auth needed.""" + return "public" + +@mcp.tool() +@requires_auth_async(provider=google) +async def protected_tool(auth: AuthContext): + """This specific tool requires auth.""" + return f"Hello {auth.username}" +``` + +--- + +## Adding Session Management + +### Basic Setup: + +```python +from nextmcp.session import FileSessionStore, SessionData + +# Create session store +session_store = FileSessionStore(".sessions") + +# Middleware automatically manages sessions +middleware = create_auth_middleware( + provider=google, + session_store=session_store, + auto_refresh_tokens=True, # Automatically refresh expiring tokens +) + +mcp.use(middleware) +``` + +### Manual Session Management: + +```python +import time +from nextmcp.session import SessionData + +# Create session manually +session = SessionData( + user_id="user123", + access_token="ya29.a0...", + refresh_token="1//01...", + expires_at=time.time() + 3600, + scopes=["profile", "email"], + user_info={"email": "user@example.com"}, + provider="google", +) + +session_store.save(session) + +# Load session +loaded = session_store.load("user123") + +# Check expiration +if loaded.needs_refresh(): + # Token expires soon, refresh it + pass + +# Clean up expired sessions +session_store.cleanup_expired() +``` + +--- + +## Migrating from Decorators to Middleware + +If you're using decorator-based auth, consider migrating to middleware for automatic enforcement: + +### Before (Decorator-Based): + +```python +from nextmcp.auth import requires_auth_async, requires_scope_async + +@mcp.tool() +@requires_auth_async(provider=google) +@requires_scope_async("read:data") +async def tool1(auth: AuthContext): + return "data" + +@mcp.tool() +@requires_auth_async(provider=google) +@requires_scope_async("read:data") +async def tool2(auth: AuthContext): + return "data" + +# Every tool needs decorators - tedious! +``` + +### After (Middleware-Based): + +```python +from nextmcp.auth import create_auth_middleware + +# One-time setup +middleware = create_auth_middleware( + provider=google, + required_scopes=["read:data"], +) + +mcp.use(middleware) + +# All tools automatically protected! +@mcp.tool() +def tool1(): + return "data" + +@mcp.tool() +def tool2(): + return "data" +``` + +### When to Use Decorators: + +Use decorators when different tools need different auth: + +```python +# Use middleware for base auth +middleware = create_auth_middleware(provider=google) +mcp.use(middleware) + +# Use decorators for tool-specific requirements +@mcp.tool() +@requires_scope_async("basic:read") +async def basic_tool(auth: AuthContext): + return "basic data" + +@mcp.tool() +@requires_scope_async("admin:write") +async def admin_tool(auth: AuthContext): + return "admin data" +``` + +--- + +## Common Patterns + +### Pattern 1: Multi-Provider Support + +```python +from nextmcp.auth import GitHubOAuthProvider, GoogleOAuthProvider + +# Set up both providers +github = GitHubOAuthProvider(client_id="...", client_secret="...") +google = GoogleOAuthProvider(client_id="...", client_secret="...") + +# You can switch providers or use different ones for different tools +# (See examples/auth/multi_provider_server.py for full example) +``` + +### Pattern 2: Per-Tool Permissions + +```python +from nextmcp.auth import PermissionManifest + +# Define permissions +manifest = PermissionManifest() +manifest.define_tool_permission("read_files", scopes=["files:read"]) +manifest.define_tool_permission("write_files", scopes=["files:write"]) +manifest.define_tool_permission("admin_panel", roles=["admin"]) + +# Apply manifest to middleware +middleware = AuthEnforcementMiddleware( + provider=google, + session_store=session_store, + manifest=manifest, +) + +mcp.use(middleware) +``` + +### Pattern 3: Custom User Data + +```python +# Store custom data in sessions +session.metadata = { + "preferences": {"theme": "dark"}, + "subscription": "premium", + "last_login": time.time(), +} + +session_store.save(session) +``` + +### Pattern 4: Token Refresh + +```python +# Automatic refresh (recommended) +middleware = create_auth_middleware( + provider=google, + session_store=session_store, + auto_refresh_tokens=True, # Enabled by default +) + +# Manual refresh +from nextmcp.auth.oauth import OAuthProvider + +session = session_store.load("user123") +if session.needs_refresh() and session.refresh_token: + # Refresh token + token_data = await provider.refresh_access_token(session.refresh_token) + + # Update session + session_store.update_tokens( + user_id="user123", + access_token=token_data["access_token"], + refresh_token=token_data.get("refresh_token"), + expires_in=token_data.get("expires_in"), + ) +``` + +--- + +## Troubleshooting + +### Problem: "No credentials provided" + +**Solution**: Ensure client sends auth credentials in request: + +```python +# Client must send: +request = { + "method": "tools/call", + "params": {"name": "my_tool"}, + "auth": { + "access_token": "ya29.a0...", + } +} +``` + +### Problem: "Authentication failed" + +**Possible causes**: +1. Invalid OAuth token +2. Token expired +3. Wrong provider credentials + +**Debug**: +```python +# Test OAuth provider directly +result = await provider.authenticate({"access_token": "..."}) +print(result.success, result.error) +``` + +### Problem: "Missing required scopes" + +**Solution**: User needs to re-authorize with additional scopes: + +```python +# Generate new auth URL with required scopes +auth_url_data = provider.generate_authorization_url() +print(auth_url_data["url"]) +# User must visit this URL +``` + +### Problem: Sessions not persisting + +**Check**: +```python +# FileSessionStore - check directory exists +session_store = FileSessionStore(".sessions") +print(list(session_store.directory.glob("session_*.json"))) + +# MemorySessionStore - sessions lost on restart (expected) +``` + +### Problem: "Token expired" errors + +**Solutions**: +1. Enable auto-refresh: + ```python + middleware = create_auth_middleware(auto_refresh_tokens=True) + ``` + +2. Ensure refresh tokens are saved: + ```python + # Check session has refresh token + session = session_store.load("user123") + print(session.refresh_token) # Should not be None + ``` + +3. Re-authenticate user if refresh fails + +--- + +## Best Practices + +1. **Always use HTTPS in production** - OAuth tokens are sensitive + +2. **Use FileSessionStore or Redis in production** - MemorySessionStore loses sessions on restart + +3. **Enable auto-refresh** - Users won't see token expiration errors + +4. **Validate scopes** - Request minimum scopes needed + +5. **Handle errors gracefully** - Show clear messages to users + +6. **Clean up expired sessions** - Run periodic cleanup: + ```python + import asyncio + + async def cleanup_loop(): + while True: + await asyncio.sleep(3600) # Every hour + session_store.cleanup_expired() + ``` + +7. **Expose auth metadata** - Let clients discover your auth requirements + +8. **Test with real OAuth** - Use integration tests with actual credentials + +--- + +## Next Steps + +- See [ARCHITECTURE.md](ARCHITECTURE.md) for how auth works internally +- See [HOST_INTEGRATION.md](HOST_INTEGRATION.md) for host integration +- Check [examples/auth/](../examples/auth/) for complete examples +- Read [OAuth Testing Setup Guide](OAUTH_TESTING_SETUP.md) for testing + +--- + +## Need Help? + +- Check examples: `examples/auth/` +- Read tests: `tests/test_request_middleware.py` +- Open an issue: https://github.com/anthropics/nextmcp/issues diff --git a/docs/OAUTH_TESTING_SETUP.md b/docs/OAUTH_TESTING_SETUP.md new file mode 100644 index 0000000..97aa95c --- /dev/null +++ b/docs/OAUTH_TESTING_SETUP.md @@ -0,0 +1,436 @@ +# OAuth Integration Testing Setup + +This guide explains how to set up OAuth credentials and obtain access tokens for running integration tests. + +## Overview + +The integration tests (`tests/test_oauth_integration.py`) verify that the OAuth implementation works with real GitHub and Google APIs. To run these tests, you need: + +1. **OAuth App Credentials** - Client ID and Secret from GitHub/Google +2. **Access Tokens** - Pre-obtained tokens for testing authenticated endpoints +3. **Environment Variables** - Configuration for the tests + +## Quick Start + +```bash +# 1. Get OAuth credentials (see detailed instructions below) +# 2. Use the helper script to obtain tokens +python examples/auth/oauth_token_helper.py + +# 3. Set environment variables +export GITHUB_CLIENT_ID="your_client_id" +export GITHUB_CLIENT_SECRET="your_client_secret" +export GITHUB_ACCESS_TOKEN="gho_..." + +export GOOGLE_CLIENT_ID="your_client_id.apps.googleusercontent.com" +export GOOGLE_CLIENT_SECRET="your_client_secret" +export GOOGLE_ACCESS_TOKEN="ya29..." +export GOOGLE_REFRESH_TOKEN="1//..." + +# 4. Run integration tests +pytest tests/test_oauth_integration.py -v -m integration +``` + +--- + +## GitHub OAuth Setup + +### Step 1: Create a GitHub OAuth App + +1. Go to **GitHub Settings** → **Developer settings** → **OAuth Apps** + - Direct link: https://github.com/settings/developers + +2. Click **"New OAuth App"** + +3. Fill in the application details: + ``` + Application name: NextMCP OAuth Testing + Homepage URL: http://localhost:8080 + Authorization callback URL: http://localhost:8080/oauth/callback + ``` + +4. Click **"Register application"** + +5. You'll see your **Client ID** - copy this + +6. Click **"Generate a new client secret"** and copy the secret + - ⚠️ Save this immediately - you won't be able to see it again! + +### Step 2: Get GitHub Access Token + +You have two options: + +#### Option A: Use the Helper Script (Recommended) + +```bash +python examples/auth/oauth_token_helper.py --provider github +``` + +The script will: +1. Generate an authorization URL +2. Open your browser to authorize +3. Start a local callback server +4. Automatically extract the access token +5. Show you the environment variables to set + +#### Option B: Manual Token Generation + +1. **Generate Authorization URL**: + ```bash + python -c " + from nextmcp.auth import GitHubOAuthProvider + provider = GitHubOAuthProvider( + client_id='YOUR_CLIENT_ID', + client_secret='YOUR_CLIENT_SECRET', + scope=['read:user', 'repo'] + ) + auth_data = provider.generate_authorization_url() + print(f'Visit: {auth_data[\"url\"]}') + print(f'Verifier: {auth_data[\"verifier\"]}') + " + ``` + +2. **Visit the URL** in your browser and click "Authorize" + +3. **Copy the code** from the callback URL: + ``` + http://localhost:8080/oauth/callback?code=AUTHORIZATION_CODE&state=... + ``` + +4. **Exchange code for token**: + ```bash + python -c " + import asyncio + from nextmcp.auth import GitHubOAuthProvider + + async def get_token(): + provider = GitHubOAuthProvider( + client_id='YOUR_CLIENT_ID', + client_secret='YOUR_CLIENT_SECRET' + ) + token_data = await provider.exchange_code_for_token( + code='AUTHORIZATION_CODE', + state='STATE_FROM_URL', + verifier='VERIFIER_FROM_STEP_1' + ) + print(f'Access Token: {token_data[\"access_token\"]}') + + asyncio.run(get_token()) + " + ``` + +5. **Set environment variable**: + ```bash + export GITHUB_ACCESS_TOKEN="gho_xxxxxxxxxxxxxxxxxxxxx" + ``` + +### Step 3: Configure Environment + +```bash +# Add to your ~/.bashrc or ~/.zshrc +export GITHUB_CLIENT_ID="your_client_id_here" +export GITHUB_CLIENT_SECRET="your_client_secret_here" +export GITHUB_ACCESS_TOKEN="gho_your_access_token_here" +``` + +Or create a `.env` file: +```bash +# .env +GITHUB_CLIENT_ID=your_client_id_here +GITHUB_CLIENT_SECRET=your_client_secret_here +GITHUB_ACCESS_TOKEN=gho_your_access_token_here +``` + +Then load it: +```bash +export $(cat .env | xargs) +``` + +--- + +## Google OAuth Setup + +### Step 1: Create a Google Cloud Project + +1. Go to **Google Cloud Console**: https://console.cloud.google.com + +2. **Create a new project**: + - Click the project dropdown at the top + - Click "New Project" + - Name: "NextMCP OAuth Testing" + - Click "Create" + +3. **Enable APIs**: + - Go to "APIs & Services" → "Library" + - Search for and enable: + - Google Drive API + - Gmail API + - Google+ API (for userinfo) + +### Step 2: Create OAuth 2.0 Credentials + +1. Go to **"APIs & Services"** → **"Credentials"** + +2. Click **"Create Credentials"** → **"OAuth client ID"** + +3. If prompted, configure the OAuth consent screen: + - User Type: **External** + - App name: **NextMCP OAuth Testing** + - User support email: Your email + - Developer contact: Your email + - Scopes: Add these scopes: + - `.../auth/userinfo.email` + - `.../auth/userinfo.profile` + - `.../auth/drive.readonly` + - `.../auth/gmail.readonly` + - Test users: Add your email address + +4. Create OAuth Client ID: + - Application type: **Web application** + - Name: **NextMCP OAuth Testing** + - Authorized redirect URIs: + - `http://localhost:8080/oauth/callback` + - Click "Create" + +5. **Download credentials** or copy: + - Client ID (ends in `.apps.googleusercontent.com`) + - Client Secret + +### Step 3: Get Google Access Token + +#### Option A: Use the Helper Script (Recommended) + +```bash +python examples/auth/oauth_token_helper.py --provider google +``` + +The script will: +1. Generate an authorization URL with offline access +2. Open your browser to authorize +3. Start a local callback server +4. Extract access token AND refresh token +5. Show you the environment variables to set + +#### Option B: Manual Token Generation + +1. **Generate Authorization URL**: + ```bash + python -c " + from nextmcp.auth import GoogleOAuthProvider + provider = GoogleOAuthProvider( + client_id='YOUR_CLIENT_ID.apps.googleusercontent.com', + client_secret='YOUR_CLIENT_SECRET', + scope=[ + 'https://www.googleapis.com/auth/userinfo.profile', + 'https://www.googleapis.com/auth/userinfo.email', + 'https://www.googleapis.com/auth/drive.readonly' + ] + ) + auth_data = provider.generate_authorization_url() + print(f'Visit: {auth_data[\"url\"]}') + print(f'Verifier: {auth_data[\"verifier\"]}') + " + ``` + +2. **Visit the URL**, sign in, and authorize the app + +3. **Copy the code** from the callback URL + +4. **Exchange code for tokens**: + ```bash + python -c " + import asyncio + from nextmcp.auth import GoogleOAuthProvider + + async def get_token(): + provider = GoogleOAuthProvider( + client_id='YOUR_CLIENT_ID', + client_secret='YOUR_CLIENT_SECRET' + ) + token_data = await provider.exchange_code_for_token( + code='AUTHORIZATION_CODE', + state='STATE_FROM_URL', + verifier='VERIFIER_FROM_STEP_1' + ) + print(f'Access Token: {token_data[\"access_token\"]}') + print(f'Refresh Token: {token_data.get(\"refresh_token\", \"N/A\")}') + + asyncio.run(get_token()) + " + ``` + +5. **Set environment variables**: + ```bash + export GOOGLE_ACCESS_TOKEN="ya29.xxxxxxxxxxxxx" + export GOOGLE_REFRESH_TOKEN="1//xxxxxxxxxxxxx" # If provided + ``` + +### Step 4: Configure Environment + +```bash +# Add to your ~/.bashrc or ~/.zshrc +export GOOGLE_CLIENT_ID="your_client_id.apps.googleusercontent.com" +export GOOGLE_CLIENT_SECRET="your_client_secret" +export GOOGLE_ACCESS_TOKEN="ya29.your_access_token" +export GOOGLE_REFRESH_TOKEN="1//your_refresh_token" +``` + +Or create a `.env` file: +```bash +# .env +GOOGLE_CLIENT_ID=your_client_id.apps.googleusercontent.com +GOOGLE_CLIENT_SECRET=your_client_secret +GOOGLE_ACCESS_TOKEN=ya29.your_access_token +GOOGLE_REFRESH_TOKEN=1//your_refresh_token +``` + +--- + +## Running the Tests + +### Run All Integration Tests + +```bash +# Activate virtual environment +source .venv/bin/activate + +# Run integration tests with verbose output +pytest tests/test_oauth_integration.py -v -m integration +``` + +### Run Specific Provider Tests + +```bash +# GitHub only +pytest tests/test_oauth_integration.py::TestGitHubOAuthIntegration -v + +# Google only +pytest tests/test_oauth_integration.py::TestGoogleOAuthIntegration -v +``` + +### Run Specific Test + +```bash +# Test GitHub user info retrieval +pytest tests/test_oauth_integration.py::TestGitHubOAuthIntegration::test_github_get_user_info -v + +# Test Google token refresh +pytest tests/test_oauth_integration.py::TestGoogleOAuthIntegration::test_google_token_refresh -v +``` + +### Skip Integration Tests (Default) + +```bash +# Regular test run automatically skips integration tests +pytest + +# Or explicitly skip them +pytest -m "not integration" +``` + +--- + +## Troubleshooting + +### "Tests skipped" message + +This means the required environment variables are not set. Check: + +```bash +# Verify environment variables are set +echo $GITHUB_CLIENT_ID +echo $GITHUB_ACCESS_TOKEN +echo $GOOGLE_CLIENT_ID +echo $GOOGLE_ACCESS_TOKEN + +# If empty, source your environment file +source ~/.bashrc # or ~/.zshrc +# or +export $(cat .env | xargs) +``` + +### "Invalid token" errors + +Access tokens expire! GitHub tokens last indefinitely (until revoked), but Google access tokens expire after 1 hour. + +**Solution**: Re-run the helper script to get a fresh token: +```bash +python examples/auth/oauth_token_helper.py --provider google +``` + +For Google, use the refresh token to get a new access token: +```bash +pytest tests/test_oauth_integration.py::TestGoogleOAuthIntegration::test_google_token_refresh -v -s +# Copy the new access token from the output +``` + +### "Redirect URI mismatch" errors + +Make sure your OAuth app has `http://localhost:8080/oauth/callback` as an authorized redirect URI. + +### Google "Access blocked: Authorization Error" + +Your app is in testing mode. Add your Google account as a test user: +1. Go to Google Cloud Console +2. APIs & Services → OAuth consent screen +3. Scroll to "Test users" +4. Click "Add Users" +5. Add your email address + +### Rate limiting + +OAuth APIs have rate limits. If you hit them: +- **GitHub**: Wait a bit or use a different account +- **Google**: Wait for the quota to reset (usually hourly) + +--- + +## Security Best Practices + +⚠️ **Never commit credentials to git!** + +Add to `.gitignore`: +``` +.env +.env.* +*_credentials.json +*_token.json +``` + +Use environment variables or a secure secrets manager in production. + +For testing, tokens with minimal scopes are recommended: +- **GitHub**: `read:user` is sufficient for basic tests +- **Google**: Use `userinfo.profile` and `userinfo.email` only + +--- + +## What Each Test Verifies + +### GitHub Tests + +1. **Authorization URL Generation** - Verifies PKCE challenge and URL formatting +2. **User Info Retrieval** - Tests GitHub API `/user` endpoint +3. **Authentication Flow** - Tests complete auth with access token +4. **Error Handling** - Verifies invalid tokens are rejected + +### Google Tests + +1. **Authorization URL Generation** - Verifies offline access parameters +2. **User Info Retrieval** - Tests Google userinfo API +3. **Authentication Flow** - Tests complete auth with access token +4. **Token Refresh** - Tests refresh token flow (unique to Google) +5. **Error Handling** - Verifies invalid tokens/refresh tokens are rejected + +--- + +## Next Steps + +Once you have integration tests passing, you can: + +1. **Build OAuth-protected tools** using the examples in `examples/auth/` +2. **Implement OAuth callback servers** for production use +3. **Add custom OAuth providers** by extending `OAuthProvider` +4. **Test with your own APIs** using the authenticated tokens + +For production deployments, see the examples in `examples/auth/` for complete OAuth flow implementations. diff --git a/examples/auth/combined_auth_server.py b/examples/auth/combined_auth_server.py new file mode 100644 index 0000000..56f0cdd --- /dev/null +++ b/examples/auth/combined_auth_server.py @@ -0,0 +1,608 @@ +""" +Combined Authentication & Authorization MCP Server Example. + +This comprehensive example demonstrates using all NextMCP auth features together: +- Multiple auth providers (API Key, JWT, OAuth) +- Role-Based Access Control (RBAC) +- Permission system +- OAuth scopes +- Permission manifests +- Error handling + +Features demonstrated: +- APIKeyProvider for service accounts +- GitHubOAuthProvider for user authentication +- RBAC with hierarchical roles +- Fine-grained permissions +- OAuth scope enforcement +- Manifest-based security +- Custom error handling + +Usage: + python examples/auth/combined_auth_server.py +""" + +import asyncio +import os +from typing import Any + +from nextmcp import NextMCP +from nextmcp.auth import ( + APIKeyProvider, + AuthContext, + GitHubOAuthProvider, + ManifestViolationError, + OAuthRequiredError, + Permission, + PermissionManifest, + RBAC, + Role, + ScopeInsufficientError, + requires_auth_async, + requires_manifest_async, + requires_permission_async, + requires_role_async, + requires_scope_async, +) + +# Initialize MCP server +mcp = NextMCP("Combined Auth Example") + +# ======================================================================== +# 1. SETUP RBAC SYSTEM +# ======================================================================== + +print("Setting up RBAC system...") + +rbac = RBAC() + +# Define roles +admin_role = rbac.define_role("admin", "Administrator with full access") +editor_role = rbac.define_role("editor", "Content editor") +viewer_role = rbac.define_role("viewer", "Read-only access") +service_role = rbac.define_role("service", "Service account for automation") + +# Define permissions +rbac.define_permission("*", "All permissions") +rbac.define_permission("read:*", "Read all") +rbac.define_permission("write:*", "Write all") +rbac.define_permission("automation:*", "Automation permissions") +rbac.define_permission("admin:all", "Admin all") +rbac.define_permission("read:posts", "Read posts") +rbac.define_permission("write:posts", "Write posts") +rbac.define_permission("read:pages", "Read pages") +rbac.define_permission("write:pages", "Write pages") +rbac.define_permission("admin:users", "Manage users") + +# Assign permissions to roles +rbac.assign_permission_to_role("admin", "*") +rbac.assign_permission_to_role("editor", "read:*") +rbac.assign_permission_to_role("editor", "write:posts") +rbac.assign_permission_to_role("editor", "write:pages") +rbac.assign_permission_to_role("viewer", "read:*") +rbac.assign_permission_to_role("service", "read:*") +rbac.assign_permission_to_role("service", "write:*") +rbac.assign_permission_to_role("service", "automation:*") + +print(f"✓ Registered {len(rbac.list_roles())} roles\n") + +# ======================================================================== +# 2. SETUP AUTH PROVIDERS +# ======================================================================== + +print("Setting up authentication providers...") + +# API Key provider for service accounts and testing +api_key_provider = APIKeyProvider( + valid_keys={ + "admin_key_123": { + "user_id": "admin1", + "username": "Admin User", + "roles": ["admin"], + "permissions": ["admin:all"], + }, + "editor_key_456": { + "user_id": "editor1", + "username": "Editor User", + "roles": ["editor"], + "permissions": ["read:posts", "write:posts", "read:pages", "write:pages"], + }, + "viewer_key_789": { + "user_id": "viewer1", + "username": "Viewer User", + "roles": ["viewer"], + "permissions": ["read:posts", "read:pages"], + }, + "service_key_abc": { + "user_id": "service1", + "username": "Automation Service", + "roles": ["service"], + "permissions": ["automation:jobs", "read:all", "write:all"], + }, + } +) + +# GitHub OAuth provider for user authentication +github_oauth = GitHubOAuthProvider( + client_id=os.getenv("GITHUB_CLIENT_ID", "demo_client_id"), + client_secret=os.getenv("GITHUB_CLIENT_SECRET"), + scope=["read:user", "repo"], +) + +print("✓ Configured API Key provider") +print("✓ Configured GitHub OAuth provider\n") + +# ======================================================================== +# 3. SETUP PERMISSION MANIFEST +# ======================================================================== + +print("Setting up permission manifest...") + +manifest = PermissionManifest() + +# Define OAuth scope mappings +manifest.define_scope( + "repo:read", + "Read repository access", + {"github": ["repo", "public_repo"]}, +) + +manifest.define_scope( + "repo:write", + "Write repository access", + {"github": ["repo"]}, +) + +# Define tool requirements +manifest.define_tool_permission( + "view_content", + roles=["viewer", "editor", "admin"], + permissions=["read:posts", "read:pages"], + description="View content", +) + +manifest.define_tool_permission( + "edit_content", + roles=["editor", "admin"], + permissions=["write:posts", "write:pages"], + description="Edit content", +) + +manifest.define_tool_permission( + "manage_users", + roles=["admin"], + permissions=["admin:users"], + description="Manage user accounts", +) + +manifest.define_tool_permission( + "github_repo_access", + scopes=["repo:read"], + description="Access GitHub repositories via OAuth", +) + +manifest.define_tool_permission( + "dangerous_operation", + roles=["admin"], + permissions=["admin:all"], + scopes=["admin:full"], + description="Dangerous admin operation", + dangerous=True, +) + +print(f"✓ Configured {len(manifest.tools)} protected tools\n") + +# ======================================================================== +# 4. DEFINE TOOLS WITH DIFFERENT AUTH STRATEGIES +# ======================================================================== + +# --- Public Tools (No Auth) --- + +@mcp.tool() +async def get_public_info() -> dict[str, Any]: + """ + Get public information. + + No authentication required. + + Returns: + Public information + """ + return { + "service": "Combined Auth MCP Server", + "version": "1.0.0", + "auth_methods": ["api_key", "github_oauth"], + "public": True, + } + + +@mcp.tool() +async def get_github_auth_url(state: str | None = None) -> dict[str, str]: + """ + Get GitHub OAuth authorization URL. + + No authentication required to get the URL. + + Args: + state: Optional CSRF protection state + + Returns: + Authorization URL and PKCE data + """ + return github_oauth.generate_authorization_url(state=state) + + +# --- Basic Auth Required --- + +@mcp.tool() +@requires_auth_async(provider=api_key_provider) +async def get_my_profile(auth: AuthContext) -> dict[str, Any]: + """ + Get authenticated user's profile. + + Requires: Any valid authentication + + Args: + auth: Authentication context (injected) + + Returns: + User profile + """ + return { + "user_id": auth.user_id, + "username": auth.username, + "roles": [r.name for r in auth.roles], + "permissions": [p.name for p in auth.permissions], + "scopes": list(auth.scopes), + } + + +# --- Role-Based Access --- + +@mcp.tool() +@requires_auth_async(provider=api_key_provider) +@requires_role_async("viewer", "editor", "admin") +async def view_content(auth: AuthContext, content_id: str) -> dict[str, Any]: + """ + View content. + + Requires: viewer, editor, or admin role + + Args: + auth: Authentication context (injected) + content_id: Content ID to view + + Returns: + Content data + """ + return { + "content_id": content_id, + "title": "Sample Article", + "body": "This is sample content...", + "author": "system", + "viewed_by": auth.username, + } + + +@mcp.tool() +@requires_auth_async(provider=api_key_provider) +@requires_role_async("editor", "admin") +async def edit_content( + auth: AuthContext, + content_id: str, + new_content: str, +) -> dict[str, Any]: + """ + Edit content. + + Requires: editor or admin role + + Args: + auth: Authentication context (injected) + content_id: Content ID to edit + new_content: New content + + Returns: + Update result + """ + return { + "success": True, + "content_id": content_id, + "updated_by": auth.username, + "message": "Content updated successfully", + } + + +@mcp.tool() +@requires_auth_async(provider=api_key_provider) +@requires_role_async("admin") +async def manage_users( + auth: AuthContext, + action: str, + user_id: str, +) -> dict[str, Any]: + """ + Manage user accounts. + + Requires: admin role + + Args: + auth: Authentication context (injected) + action: Action to perform (create, delete, modify) + user_id: Target user ID + + Returns: + Action result + """ + return { + "success": True, + "action": action, + "target_user": user_id, + "performed_by": auth.username, + "message": f"User {action} completed", + } + + +# --- Permission-Based Access --- + +@mcp.tool() +@requires_auth_async(provider=api_key_provider) +@requires_permission_async("write:posts") +async def create_post( + auth: AuthContext, + title: str, + content: str, +) -> dict[str, Any]: + """ + Create a new post. + + Requires: write:posts permission + + Args: + auth: Authentication context (injected) + title: Post title + content: Post content + + Returns: + Created post + """ + return { + "success": True, + "post_id": "post123", + "title": title, + "author": auth.username, + "message": "Post created successfully", + } + + +# --- Manifest-Based Access --- + +@mcp.tool() +@requires_auth_async(provider=api_key_provider) +@requires_manifest_async(manifest=manifest) +async def dangerous_operation( + auth: AuthContext, + confirmation: str, +) -> dict[str, Any]: + """ + Perform dangerous admin operation. + + Requires: admin role + admin:all permission + admin:full scope + (enforced by manifest, marked as dangerous) + + Args: + auth: Authentication context (injected) + confirmation: Confirmation string + + Returns: + Operation result + """ + if confirmation != "I CONFIRM": + return {"success": False, "error": "Invalid confirmation"} + + return { + "success": True, + "performed_by": auth.username, + "warning": "Dangerous operation completed", + } + + +# --- OAuth Scope-Based Access (if using GitHub OAuth) --- + +# Note: This would use github_oauth provider instead of api_key_provider +# Shown as example - would need actual OAuth flow to use + +@mcp.tool() +async def example_github_tool_info() -> dict[str, str]: + """ + Example info about GitHub OAuth tool. + + This demonstrates how to use GitHub OAuth. + + Returns: + Example information + """ + return { + "note": "To use GitHub OAuth tools:", + "step1": "Call get_github_auth_url() to get authorization URL", + "step2": "User authorizes at that URL", + "step3": "Exchange code for access token", + "step4": "Use access token in auth parameter", + "example_tool": "list_user_repos (requires repo scope)", + } + + +# ======================================================================== +# 5. ERROR HANDLING DEMONSTRATION +# ======================================================================== + +async def demonstrate_errors(): + """Demonstrate specialized error handling.""" + print("\n" + "=" * 60) + print("ERROR HANDLING DEMONSTRATION") + print("=" * 60 + "\n") + + # 1. OAuthRequiredError + print("1. OAuthRequiredError - when OAuth is needed:") + try: + raise OAuthRequiredError( + "GitHub OAuth required for this operation", + provider="github", + scopes=["repo"], + authorization_url=github_oauth.generate_authorization_url()["url"], + ) + except OAuthRequiredError as e: + print(f" Error: {e}") + print(f" Provider: {e.provider}") + print(f" Required scopes: {e.scopes}") + print(f" Auth URL: {e.authorization_url[:50]}...") + print() + + # 2. ScopeInsufficientError + print("2. ScopeInsufficientError - when user lacks scopes:") + try: + raise ScopeInsufficientError( + "Insufficient OAuth scopes", + required_scopes=["repo:write"], + current_scopes=["repo:read"], + user_id="user123", + ) + except ScopeInsufficientError as e: + print(f" Error: {e}") + print(f" Required: {e.required_scopes}") + print(f" Current: {e.current_scopes}") + print(f" User: {e.user_id}") + print() + + # 3. ManifestViolationError + print("3. ManifestViolationError - when manifest check fails:") + try: + raise ManifestViolationError( + "Access denied by manifest", + tool_name="dangerous_operation", + required_roles=["admin"], + required_permissions=["admin:all"], + required_scopes=["admin:full"], + user_id="user123", + ) + except ManifestViolationError as e: + print(f" Error: {e}") + print(f" Tool: {e.tool_name}") + print(f" Required roles: {e.required_roles}") + print(f" Required permissions: {e.required_permissions}") + print(f" Required scopes: {e.required_scopes}") + print() + + +# ======================================================================== +# 6. COMPREHENSIVE DEMONSTRATION +# ======================================================================== + +async def demonstrate_all_features(): + """Demonstrate all auth features.""" + print("\n" + "=" * 60) + print("COMPREHENSIVE AUTH DEMONSTRATION") + print("=" * 60 + "\n") + + # Test 1: Public access + print("Test 1: Public access (no auth required)") + result = await get_public_info() + print(f"✓ {result['service']}\n") + + # Test 2: Basic auth + print("Test 2: Basic auth (API key)") + result = await get_my_profile(auth={"api_key": "viewer_key_789"}) + print(f"✓ Authenticated as {result['username']}") + print(f" Roles: {', '.join(result['roles'])}\n") + + # Test 3: Role-based access (allowed) + print("Test 3: Role-based access - viewer viewing content (allowed)") + try: + result = await view_content( + auth={"api_key": "viewer_key_789"}, + content_id="article1", + ) + print(f"✓ {result['viewed_by']} viewed content\n") + except Exception as e: + print(f"✗ {e}\n") + + # Test 4: Role-based access (denied) + print("Test 4: Role-based access - viewer editing content (denied)") + try: + result = await edit_content( + auth={"api_key": "viewer_key_789"}, + content_id="article1", + new_content="Updated", + ) + print(f"✓ Unexpected success\n") + except Exception as e: + print(f"✓ Correctly denied: Permission denied\n") + + # Test 5: Editor can edit + print("Test 5: Role-based access - editor editing content (allowed)") + try: + result = await edit_content( + auth={"api_key": "editor_key_456"}, + content_id="article1", + new_content="Updated content", + ) + print(f"✓ {result['updated_by']} edited content\n") + except Exception as e: + print(f"✗ {e}\n") + + # Test 6: Admin access + print("Test 6: Admin-only access - managing users (allowed)") + try: + result = await manage_users( + auth={"api_key": "admin_key_123"}, + action="create", + user_id="newuser1", + ) + print(f"✓ {result['performed_by']} managed users\n") + except Exception as e: + print(f"✗ {e}\n") + + # Test 7: Permission-based access + print("Test 7: Permission-based access - creating post (allowed)") + try: + result = await create_post( + auth={"api_key": "editor_key_456"}, + title="New Post", + content="Post content", + ) + print(f"✓ {result['author']} created post\n") + except Exception as e: + print(f"✗ {e}\n") + + # Error handling demonstration + await demonstrate_errors() + + # Summary + print("=" * 60) + print("FEATURE SUMMARY") + print("=" * 60) + print(f"Auth Providers: API Key, GitHub OAuth") + print(f"Roles: {', '.join(r.name for r in rbac.list_roles())}") + print(f"Protected Tools: {len(manifest.tools)}") + print(f"Error Types: 3 specialized exceptions") + print("=" * 60) + + +if __name__ == "__main__": + print("Starting Combined Auth MCP Server...\n") + + # Run comprehensive demonstration + asyncio.run(demonstrate_all_features()) + + print("\n" + "=" * 60) + print("KEY FEATURES DEMONSTRATED") + print("=" * 60) + print("1. Multiple Auth Providers (API Key, OAuth)") + print("2. Role-Based Access Control (RBAC)") + print("3. Fine-Grained Permissions") + print("4. OAuth Scopes") + print("5. Permission Manifests") + print("6. Specialized Error Types") + print("7. Declarative Security") + print("8. Middleware Decorators") + print("=" * 60) diff --git a/examples/auth/complete_oauth_server.py b/examples/auth/complete_oauth_server.py new file mode 100644 index 0000000..06c8e9c --- /dev/null +++ b/examples/auth/complete_oauth_server.py @@ -0,0 +1,347 @@ +""" +Complete OAuth Server Example + +This example demonstrates a production-ready MCP server with: +- Google OAuth authentication +- Session management with file storage +- Automatic token refresh +- Auth metadata exposure +- Protected and public tools +- Comprehensive error handling + +Run this server: + python examples/auth/complete_oauth_server.py + +Then use examples/auth/oauth_token_helper.py to get tokens and test. +""" + +import asyncio +import os +from pathlib import Path + +from fastmcp import FastMCP + +from nextmcp.auth import ( + GoogleOAuthProvider, + create_auth_middleware, +) +from nextmcp.protocol import ( + AuthFlowType, + AuthMetadata, + AuthRequirement, +) +from nextmcp.session import FileSessionStore + +# ============================================================================ +# Configuration +# ============================================================================ + +# Get OAuth credentials from environment +GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID", "your-client-id") +GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET", "your-client-secret") + +# Session storage directory +SESSION_DIR = Path(".nextmcp_sessions") + + +# ============================================================================ +# Initialize Server +# ============================================================================ + +mcp = FastMCP("Complete OAuth Server") + + +# ============================================================================ +# Set Up OAuth Provider +# ============================================================================ + +google = GoogleOAuthProvider( + client_id=GOOGLE_CLIENT_ID, + client_secret=GOOGLE_CLIENT_SECRET, + redirect_uri="http://localhost:8080/oauth/callback", + scope=["openid", "email", "profile", "https://www.googleapis.com/auth/drive.readonly"], +) + + +# ============================================================================ +# Set Up Session Store +# ============================================================================ + +session_store = FileSessionStore(SESSION_DIR) + + +# ============================================================================ +# Create Auth Metadata +# ============================================================================ + +auth_metadata = AuthMetadata( + requirement=AuthRequirement.REQUIRED, + supports_multi_user=True, + token_refresh_enabled=True, + session_management="server-side", +) + +# Add Google OAuth provider +auth_metadata.add_provider( + name="google", + type="oauth2", + flows=[AuthFlowType.OAUTH2_PKCE], + authorization_url="https://accounts.google.com/o/oauth2/v2/auth", + token_url="https://oauth2.googleapis.com/token", + scopes=["openid", "email", "profile", "https://www.googleapis.com/auth/drive.readonly"], + supports_refresh=True, + supports_pkce=True, +) + +# Add required scopes +auth_metadata.add_required_scope("openid") +auth_metadata.add_required_scope("email") + +# Add optional scopes +auth_metadata.add_optional_scope("profile") +auth_metadata.add_optional_scope("https://www.googleapis.com/auth/drive.readonly") + +# Add error code documentation +auth_metadata.error_codes = { + "authentication_required": "You must be authenticated to access this server", + "authorization_denied": "You lack the required permissions to access this resource", + "token_expired": "Your access token has expired - please refresh or re-authenticate", + "insufficient_scopes": "Additional OAuth scopes are required for this operation", +} + + +# ============================================================================ +# Apply Auth Middleware +# ============================================================================ + +auth_middleware = create_auth_middleware( + provider=google, + requirement=AuthRequirement.REQUIRED, + session_store=session_store, + required_scopes=["openid", "email"], +) + +mcp.use(auth_middleware) + + +# ============================================================================ +# Public Tools (No special auth needed - middleware handles it) +# ============================================================================ + + +@mcp.tool() +def get_auth_metadata() -> dict: + """ + Get server authentication requirements. + + This tool returns information about what authentication is required, + which OAuth providers are supported, and what permissions are needed. + + Returns: + dict: Complete authentication metadata + """ + return auth_metadata.to_dict() + + +@mcp.tool() +def get_server_info() -> dict: + """ + Get server information. + + Returns basic information about this MCP server. + + Returns: + dict: Server name, version, and capabilities + """ + return { + "name": "Complete OAuth Server", + "version": "1.0.0", + "auth_enabled": True, + "session_storage": "file", + "features": [ + "OAuth 2.0 with PKCE", + "Session management", + "Token refresh", + "Multi-user support", + ], + } + + +# ============================================================================ +# Protected Tools (Require authentication) +# ============================================================================ + + +@mcp.tool() +def get_user_profile() -> dict: + """ + Get authenticated user's profile. + + Returns information about the currently authenticated user from their + OAuth provider profile. + + Returns: + dict: User's email, name, and other profile information + """ + # Note: In a real implementation, you would access the auth context + # from the request to get user info. For now, this is a placeholder. + return { + "message": "This would return the authenticated user's profile", + "note": "Access auth context via request['_auth_context']", + } + + +@mcp.tool() +def list_user_files() -> dict: + """ + List user's Google Drive files. + + Lists files from the authenticated user's Google Drive using their + OAuth access token. + + Returns: + dict: List of files with names and IDs + + Raises: + AuthorizationError: If user hasn't granted drive.readonly scope + """ + return { + "message": "This would list files from user's Google Drive", + "required_scope": "https://www.googleapis.com/auth/drive.readonly", + "note": "Actual implementation would use access_token to call Drive API", + } + + +@mcp.tool() +def create_personalized_content(topic: str) -> str: + """ + Create personalized content for the authenticated user. + + Generates content tailored to the authenticated user based on their + profile and the requested topic. + + Args: + topic: The topic to generate content about + + Returns: + str: Personalized content + + Example: + >>> create_personalized_content("machine learning") + "Hello John! Here's ML content personalized for you..." + """ + return f"Creating personalized content about {topic} for authenticated user..." + + +# ============================================================================ +# Session Management Tools +# ============================================================================ + + +@mcp.tool() +def get_active_sessions() -> dict: + """ + Get list of active user sessions. + + Returns information about all currently active sessions stored on + the server. + + Returns: + dict: Number of active sessions and their user IDs + """ + users = session_store.list_users() + + return { + "active_sessions": len(users), + "users": users, + "storage_type": "file", + "storage_location": str(SESSION_DIR.absolute()), + } + + +@mcp.tool() +def cleanup_expired_sessions() -> dict: + """ + Clean up expired sessions. + + Removes all sessions with expired access tokens from storage. + + Returns: + dict: Number of sessions cleaned up + """ + cleaned = session_store.cleanup_expired() + + return { + "cleaned_sessions": cleaned, + "message": f"Removed {cleaned} expired session(s)", + } + + +# ============================================================================ +# Utility Functions +# ============================================================================ + + +def print_startup_message(): + """Print helpful startup information.""" + print("\n" + "=" * 70) + print("🚀 Complete OAuth Server Started") + print("=" * 70) + print("\nConfiguration:") + print(f" OAuth Provider: Google") + print(f" Client ID: {GOOGLE_CLIENT_ID}") + print(f" Session Storage: {SESSION_DIR.absolute()}") + print(f" Auth Required: Yes") + print(f" Token Refresh: Enabled") + print("\nAvailable Tools:") + print(" 📋 Public:") + print(" - get_auth_metadata: Get auth requirements") + print(" - get_server_info: Get server information") + print("\n 🔐 Protected (requires authentication):") + print(" - get_user_profile: Get user's profile") + print(" - list_user_files: List Google Drive files") + print(" - create_personalized_content: Generate personalized content") + print(" - get_active_sessions: View active sessions") + print(" - cleanup_expired_sessions: Remove expired sessions") + print("\nTo get OAuth tokens:") + print(f" 1. Set environment variables:") + print(f" export GOOGLE_CLIENT_ID='{GOOGLE_CLIENT_ID}'") + print(f" export GOOGLE_CLIENT_SECRET='{GOOGLE_CLIENT_SECRET}'") + print("\n 2. Run token helper:") + print(" python examples/auth/oauth_token_helper.py --provider google") + print("\n 3. Test authenticated requests:") + print(" Use the access token from step 2 in your MCP client") + print("\n" + "=" * 70 + "\n") + + +async def periodic_cleanup(): + """Periodically clean up expired sessions.""" + while True: + await asyncio.sleep(3600) # Every hour + cleaned = session_store.cleanup_expired() + if cleaned > 0: + print(f"[Cleanup] Removed {cleaned} expired session(s)") + + +# ============================================================================ +# Main Entry Point +# ============================================================================ + + +def main(): + """Run the server.""" + # Print startup information + print_startup_message() + + # Ensure session directory exists + SESSION_DIR.mkdir(parents=True, exist_ok=True) + + # Start periodic cleanup task + # asyncio.create_task(periodic_cleanup()) + + # Run server + mcp.run() + + +if __name__ == "__main__": + main() diff --git a/examples/auth/github_oauth_server.py b/examples/auth/github_oauth_server.py new file mode 100644 index 0000000..66a7da9 --- /dev/null +++ b/examples/auth/github_oauth_server.py @@ -0,0 +1,330 @@ +""" +GitHub OAuth MCP Server Example. + +This example demonstrates how to build an MCP server with GitHub OAuth authentication. +Users authenticate via GitHub OAuth 2.0 and can access tools based on their OAuth scopes. + +Features demonstrated: +- GitHub OAuth 2.0 with PKCE +- Scope-based access control +- User repository access +- Authorization URL generation + +Setup: +1. Create a GitHub OAuth App at https://github.com/settings/developers +2. Set redirect URI to: http://localhost:8080/oauth/callback +3. Copy Client ID and Client Secret +4. Set environment variables: + export GITHUB_CLIENT_ID="your_client_id" + export GITHUB_CLIENT_SECRET="your_client_secret" # Optional for PKCE + +Usage: + python examples/auth/github_oauth_server.py +""" + +import asyncio +import os +from typing import Any + +from nextmcp import NextMCP +from nextmcp.auth import ( + AuthContext, + GitHubOAuthProvider, + requires_auth_async, + requires_scope_async, +) + +# Initialize MCP server +mcp = NextMCP("GitHub OAuth Example") + +# Configure GitHub OAuth provider +github_oauth = GitHubOAuthProvider( + client_id=os.getenv("GITHUB_CLIENT_ID", "your_github_client_id"), + client_secret=os.getenv("GITHUB_CLIENT_SECRET"), # Optional with PKCE + redirect_uri="http://localhost:8080/oauth/callback", + scope=["read:user", "repo"], # Requested scopes +) + + +@mcp.tool() +async def get_authorization_url(state: str | None = None) -> dict[str, str]: + """ + Get GitHub OAuth authorization URL. + + This tool generates the URL users should visit to authorize the app. + No authentication required to call this tool. + + Args: + state: Optional state parameter for CSRF protection + + Returns: + Dict with 'url', 'state', and 'verifier' (store verifier securely!) + + Example: + result = await get_authorization_url() + # Send user to result['url'] + # Store result['verifier'] for later token exchange + """ + return github_oauth.generate_authorization_url(state=state) + + +@mcp.tool() +@requires_auth_async(provider=github_oauth) +async def get_my_profile(auth: AuthContext) -> dict[str, Any]: + """ + Get the authenticated user's GitHub profile. + + Requires OAuth authentication with 'read:user' scope. + + Args: + auth: Authentication context (injected by middleware) + + Returns: + User profile information from GitHub + + Example: + # After OAuth flow completes with access token + profile = await get_my_profile(auth={ + "access_token": "gho_...", + "scopes": ["read:user", "repo"] + }) + """ + # Access token is available in auth.metadata + access_token = auth.metadata.get("access_token") + + # Get user info from GitHub + user_info = auth.metadata.get("user_info", {}) + + return { + "user_id": auth.user_id, + "username": auth.username, + "name": user_info.get("name"), + "email": user_info.get("email"), + "bio": user_info.get("bio"), + "company": user_info.get("company"), + "location": user_info.get("location"), + "scopes": list(auth.scopes), + } + + +@mcp.tool() +@requires_auth_async(provider=github_oauth) +@requires_scope_async("repo") +async def list_my_repositories( + auth: AuthContext, + visibility: str = "all", + sort: str = "updated", +) -> dict[str, Any]: + """ + List the authenticated user's repositories. + + Requires OAuth authentication with 'repo' scope. + + Args: + auth: Authentication context (injected by middleware) + visibility: Repository visibility filter (all, public, private) + sort: Sort order (created, updated, pushed, full_name) + + Returns: + List of user's repositories + + Example: + repos = await list_my_repositories( + auth={"access_token": "gho_...", "scopes": ["repo"]}, + visibility="public", + sort="updated" + ) + """ + import aiohttp + + access_token = auth.metadata.get("access_token") + + # Call GitHub API + headers = { + "Authorization": f"Bearer {access_token}", + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + } + + params = { + "visibility": visibility, + "sort": sort, + "per_page": 10, # Limit for example + } + + async with aiohttp.ClientSession() as session: + async with session.get( + "https://api.github.com/user/repos", + headers=headers, + params=params, + ) as resp: + if resp.status == 200: + repos = await resp.json() + return { + "count": len(repos), + "repositories": [ + { + "name": repo["name"], + "full_name": repo["full_name"], + "description": repo["description"], + "private": repo["private"], + "url": repo["html_url"], + "stars": repo["stargazers_count"], + "language": repo["language"], + "updated_at": repo["updated_at"], + } + for repo in repos + ], + } + else: + error_data = await resp.json() + return {"error": f"GitHub API error: {error_data}"} + + +@mcp.tool() +@requires_auth_async(provider=github_oauth) +@requires_scope_async("repo") +async def create_repository( + auth: AuthContext, + name: str, + description: str = "", + private: bool = False, +) -> dict[str, Any]: + """ + Create a new GitHub repository. + + Requires OAuth authentication with 'repo' scope. + + Args: + auth: Authentication context (injected by middleware) + name: Repository name + description: Repository description + private: Whether repository should be private + + Returns: + Created repository information + + Example: + repo = await create_repository( + auth={"access_token": "gho_...", "scopes": ["repo"]}, + name="my-new-repo", + description="Created via MCP", + private=False + ) + """ + import aiohttp + + access_token = auth.metadata.get("access_token") + + headers = { + "Authorization": f"Bearer {access_token}", + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + } + + data = { + "name": name, + "description": description, + "private": private, + "auto_init": True, # Initialize with README + } + + async with aiohttp.ClientSession() as session: + async with session.post( + "https://api.github.com/user/repos", + headers=headers, + json=data, + ) as resp: + if resp.status == 201: + repo = await resp.json() + return { + "success": True, + "repository": { + "name": repo["name"], + "full_name": repo["full_name"], + "url": repo["html_url"], + "clone_url": repo["clone_url"], + "private": repo["private"], + }, + } + else: + error_data = await resp.json() + return { + "success": False, + "error": f"Failed to create repository: {error_data}", + } + + +# OAuth Flow Example +async def example_oauth_flow(): + """ + Example of complete OAuth flow. + + This demonstrates the full OAuth 2.0 authorization code flow with PKCE. + """ + print("=== GitHub OAuth Flow Example ===\n") + + # Step 1: Generate authorization URL + print("Step 1: Generating authorization URL...") + auth_data = github_oauth.generate_authorization_url() + print(f"Authorization URL: {auth_data['url']}") + print(f"State: {auth_data['state']}") + print(f"PKCE Verifier: {auth_data['verifier'][:20]}...") + print("\nUser should visit this URL and authorize the app.\n") + + # Step 2: After user authorizes, you receive a code + # (This is normally done via a web callback) + print("Step 2: User authorizes and you receive authorization code...") + print("(In real app, this comes from OAuth callback)\n") + + # Simulating received code + authorization_code = "simulated_code_from_github" + state_from_callback = auth_data["state"] + verifier = auth_data["verifier"] + + # Step 3: Exchange code for access token + print("Step 3: Exchanging code for access token...") + try: + # This would actually exchange the code (requires real code from GitHub) + # token_data = await github_oauth.exchange_code_for_token( + # code=authorization_code, + # state=state_from_callback, + # verifier=verifier + # ) + print("(Skipping actual exchange - requires real authorization code)") + print("Token data would contain: access_token, refresh_token, scope, etc.\n") + except Exception as e: + print(f"Note: {e}\n") + + # Step 4: Authenticate with access token + print("Step 4: Using access token to authenticate...") + print("(In real app, pass access_token to tool calls as auth credentials)") + print("\nExample tool call:") + print(' profile = await get_my_profile(auth={') + print(' "access_token": "gho_...",') + print(' "scopes": ["read:user", "repo"]') + print(' })\n') + + +if __name__ == "__main__": + # Run the server + print("Starting GitHub OAuth MCP Server...") + print("\nAvailable tools:") + print(" - get_authorization_url(): Get OAuth URL") + print(" - get_my_profile(): Get authenticated user's profile") + print(" - list_my_repositories(): List user's repositories") + print(" - create_repository(): Create a new repository") + print("\nRunning OAuth flow example...\n") + + # Show OAuth flow example + asyncio.run(example_oauth_flow()) + + print("\n" + "=" * 60) + print("To use this server with MCP:") + print("1. Set GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET env vars") + print("2. Run: mcp run examples/auth/github_oauth_server.py") + print("3. Call get_authorization_url() to get OAuth URL") + print("4. Have user authorize at that URL") + print("5. Exchange code for token (handle OAuth callback)") + print("6. Use access token in subsequent tool calls") + print("=" * 60) diff --git a/examples/auth/google_oauth_server.py b/examples/auth/google_oauth_server.py new file mode 100644 index 0000000..b0a2c0a --- /dev/null +++ b/examples/auth/google_oauth_server.py @@ -0,0 +1,394 @@ +""" +Google OAuth MCP Server Example. + +This example demonstrates how to build an MCP server with Google OAuth authentication. +Users authenticate via Google OAuth 2.0 and can access tools based on their OAuth scopes. + +Features demonstrated: +- Google OAuth 2.0 with PKCE +- Offline access (refresh tokens) +- Scope-based access control +- Google Drive file access +- Gmail integration + +Setup: +1. Create a Google Cloud Project at https://console.cloud.google.com +2. Enable Google Drive API and Gmail API +3. Create OAuth 2.0 credentials (Web application) +4. Add authorized redirect URI: http://localhost:8080/oauth/callback +5. Download credentials and set environment variables: + export GOOGLE_CLIENT_ID="your_client_id.apps.googleusercontent.com" + export GOOGLE_CLIENT_SECRET="your_client_secret" + +Usage: + python examples/auth/google_oauth_server.py +""" + +import asyncio +import os +from typing import Any + +from nextmcp import NextMCP +from nextmcp.auth import ( + AuthContext, + GoogleOAuthProvider, + requires_auth_async, + requires_scope_async, +) + +# Initialize MCP server +mcp = NextMCP("Google OAuth Example") + +# Configure Google OAuth provider +google_oauth = GoogleOAuthProvider( + client_id=os.getenv("GOOGLE_CLIENT_ID", "your_google_client_id"), + client_secret=os.getenv("GOOGLE_CLIENT_SECRET", "your_google_client_secret"), + redirect_uri="http://localhost:8080/oauth/callback", + scope=[ + "https://www.googleapis.com/auth/userinfo.profile", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/drive.readonly", + "https://www.googleapis.com/auth/gmail.readonly", + ], +) + + +@mcp.tool() +async def get_authorization_url(state: str | None = None) -> dict[str, str]: + """ + Get Google OAuth authorization URL. + + This tool generates the URL users should visit to authorize the app. + Requests offline access for refresh tokens. + + Args: + state: Optional state parameter for CSRF protection + + Returns: + Dict with 'url', 'state', and 'verifier' (store verifier securely!) + + Example: + result = await get_authorization_url() + # Send user to result['url'] + # Store result['verifier'] for token exchange + """ + return google_oauth.generate_authorization_url(state=state) + + +@mcp.tool() +@requires_auth_async(provider=google_oauth) +async def get_my_profile(auth: AuthContext) -> dict[str, Any]: + """ + Get the authenticated user's Google profile. + + Requires OAuth authentication with profile and email scopes. + + Args: + auth: Authentication context (injected by middleware) + + Returns: + User profile information from Google + + Example: + profile = await get_my_profile(auth={ + "access_token": "ya29...", + "scopes": [ + "https://www.googleapis.com/auth/userinfo.profile", + "https://www.googleapis.com/auth/userinfo.email" + ] + }) + """ + user_info = auth.metadata.get("user_info", {}) + + return { + "user_id": auth.user_id, + "email": auth.username, # GoogleOAuthProvider uses email as username + "name": user_info.get("name"), + "given_name": user_info.get("given_name"), + "family_name": user_info.get("family_name"), + "picture": user_info.get("picture"), + "locale": user_info.get("locale"), + "scopes": list(auth.scopes), + } + + +@mcp.tool() +@requires_auth_async(provider=google_oauth) +@requires_scope_async("https://www.googleapis.com/auth/drive.readonly") +async def list_drive_files( + auth: AuthContext, + page_size: int = 10, + query: str | None = None, +) -> dict[str, Any]: + """ + List files in user's Google Drive. + + Requires OAuth authentication with Drive read scope. + + Args: + auth: Authentication context (injected by middleware) + page_size: Number of files to return (max 100) + query: Optional search query (e.g., "name contains 'report'") + + Returns: + List of Drive files + + Example: + files = await list_drive_files( + auth={ + "access_token": "ya29...", + "scopes": ["https://www.googleapis.com/auth/drive.readonly"] + }, + page_size=10, + query="mimeType='application/pdf'" + ) + """ + import aiohttp + + access_token = auth.metadata.get("access_token") + + headers = { + "Authorization": f"Bearer {access_token}", + "Accept": "application/json", + } + + params = { + "pageSize": min(page_size, 100), + "fields": "files(id,name,mimeType,createdTime,modifiedTime,size,webViewLink)", + } + + if query: + params["q"] = query + + async with aiohttp.ClientSession() as session: + async with session.get( + "https://www.googleapis.com/drive/v3/files", + headers=headers, + params=params, + ) as resp: + if resp.status == 200: + data = await resp.json() + files = data.get("files", []) + return { + "count": len(files), + "files": [ + { + "id": file["id"], + "name": file["name"], + "type": file["mimeType"], + "created": file.get("createdTime"), + "modified": file.get("modifiedTime"), + "size": file.get("size"), + "link": file.get("webViewLink"), + } + for file in files + ], + } + else: + error_data = await resp.json() + return {"error": f"Google Drive API error: {error_data}"} + + +@mcp.tool() +@requires_auth_async(provider=google_oauth) +@requires_scope_async("https://www.googleapis.com/auth/gmail.readonly") +async def list_gmail_messages( + auth: AuthContext, + max_results: int = 10, + query: str | None = None, +) -> dict[str, Any]: + """ + List messages in user's Gmail inbox. + + Requires OAuth authentication with Gmail read scope. + + Args: + auth: Authentication context (injected by middleware) + max_results: Number of messages to return + query: Optional Gmail search query (e.g., "is:unread") + + Returns: + List of Gmail messages + + Example: + messages = await list_gmail_messages( + auth={ + "access_token": "ya29...", + "scopes": ["https://www.googleapis.com/auth/gmail.readonly"] + }, + max_results=5, + query="is:unread" + ) + """ + import aiohttp + + access_token = auth.metadata.get("access_token") + + headers = { + "Authorization": f"Bearer {access_token}", + "Accept": "application/json", + } + + params = { + "maxResults": max_results, + } + + if query: + params["q"] = query + + async with aiohttp.ClientSession() as session: + # Get message list + async with session.get( + "https://gmail.googleapis.com/gmail/v1/users/me/messages", + headers=headers, + params=params, + ) as resp: + if resp.status == 200: + data = await resp.json() + messages = data.get("messages", []) + + # Get details for each message + detailed_messages = [] + for msg in messages: + async with session.get( + f"https://gmail.googleapis.com/gmail/v1/users/me/messages/{msg['id']}", + headers=headers, + params={"format": "metadata", "metadataHeaders": ["From", "Subject", "Date"]}, + ) as detail_resp: + if detail_resp.status == 200: + detail = await detail_resp.json() + headers_dict = { + h["name"]: h["value"] for h in detail.get("payload", {}).get("headers", []) + } + detailed_messages.append({ + "id": detail["id"], + "from": headers_dict.get("From", "Unknown"), + "subject": headers_dict.get("Subject", "No Subject"), + "date": headers_dict.get("Date", "Unknown"), + "snippet": detail.get("snippet", ""), + }) + + return { + "count": len(detailed_messages), + "messages": detailed_messages, + } + else: + error_data = await resp.json() + return {"error": f"Gmail API error: {error_data}"} + + +@mcp.tool() +@requires_auth_async(provider=google_oauth) +async def refresh_access_token( + auth: AuthContext, + refresh_token: str, +) -> dict[str, Any]: + """ + Refresh an expired access token. + + Uses a refresh token to obtain a new access token. + Google OAuth provides refresh tokens with offline access. + + Args: + auth: Authentication context (injected by middleware) + refresh_token: The refresh token from initial OAuth flow + + Returns: + New access token and expiration info + + Example: + new_token = await refresh_access_token( + auth={ + "access_token": "old_token", # Can be expired + "scopes": [...] + }, + refresh_token="1//..." + ) + """ + try: + token_data = await google_oauth.refresh_access_token(refresh_token) + return { + "success": True, + "access_token": token_data.get("access_token"), + "expires_in": token_data.get("expires_in"), + "scope": token_data.get("scope"), + "token_type": token_data.get("token_type"), + } + except Exception as e: + return { + "success": False, + "error": str(e), + } + + +# OAuth Flow Example +async def example_oauth_flow(): + """ + Example of complete OAuth flow with Google. + + Demonstrates OAuth 2.0 authorization code flow with PKCE and refresh tokens. + """ + print("=== Google OAuth Flow Example ===\n") + + # Step 1: Generate authorization URL + print("Step 1: Generating authorization URL...") + auth_data = google_oauth.generate_authorization_url() + print(f"Authorization URL: {auth_data['url'][:80]}...") + print(f"State: {auth_data['state']}") + print(f"PKCE Verifier: {auth_data['verifier'][:20]}...") + print("\nUser should visit this URL and authorize the app.") + print("Note: URL includes access_type=offline for refresh tokens\n") + + # Step 2: After user authorizes, you receive a code + print("Step 2: User authorizes and you receive authorization code...") + print("(In real app, this comes from OAuth callback)\n") + + # Step 3: Exchange code for tokens + print("Step 3: Exchanging code for access and refresh tokens...") + print("(Skipping actual exchange - requires real authorization code)") + print("Token data would contain:") + print(" - access_token: For immediate API access") + print(" - refresh_token: For getting new access tokens") + print(" - expires_in: Token lifetime (typically 3600 seconds)") + print(" - scope: Granted scopes\n") + + # Step 4: Using tokens + print("Step 4: Using tokens...") + print("Access token is used for API calls:") + print(' files = await list_drive_files(auth={') + print(' "access_token": "ya29...",') + print(' "scopes": ["https://www.googleapis.com/auth/drive.readonly"]') + print(' })\n') + + print("When access token expires, use refresh token:") + print(' new_token = await refresh_access_token(') + print(' auth={"access_token": "old_token", "scopes": [...]},') + print(' refresh_token="1//..."') + print(' )\n') + + +if __name__ == "__main__": + # Run the server + print("Starting Google OAuth MCP Server...") + print("\nAvailable tools:") + print(" - get_authorization_url(): Get OAuth URL") + print(" - get_my_profile(): Get authenticated user's profile") + print(" - list_drive_files(): List Google Drive files") + print(" - list_gmail_messages(): List Gmail messages") + print(" - refresh_access_token(): Refresh expired tokens") + print("\nRunning OAuth flow example...\n") + + # Show OAuth flow example + asyncio.run(example_oauth_flow()) + + print("\n" + "=" * 60) + print("To use this server with MCP:") + print("1. Set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET env vars") + print("2. Enable Google Drive and Gmail APIs in Google Cloud Console") + print("3. Run: mcp run examples/auth/google_oauth_server.py") + print("4. Call get_authorization_url() to get OAuth URL") + print("5. Have user authorize at that URL") + print("6. Exchange code for tokens (handle OAuth callback)") + print("7. Use access token in subsequent tool calls") + print("8. Use refresh token when access token expires") + print("=" * 60) diff --git a/examples/auth/manifest_server.py b/examples/auth/manifest_server.py new file mode 100644 index 0000000..578d564 --- /dev/null +++ b/examples/auth/manifest_server.py @@ -0,0 +1,462 @@ +""" +Permission Manifest MCP Server Example. + +This example demonstrates how to use PermissionManifest for declarative security. +Define security policies in YAML and enforce them automatically via decorators. + +Features demonstrated: +- Loading manifests from YAML files +- Declarative security definitions +- Automatic manifest enforcement with @requires_manifest_async +- Role, permission, and scope requirements +- Dangerous operation flagging + +Setup: +1. Create a manifest YAML file (see manifest.yaml below) +2. Initialize manifest and load YAML +3. Apply @requires_manifest_async to protected tools + +Usage: + python examples/auth/manifest_server.py +""" + +import asyncio +import tempfile +from pathlib import Path +from typing import Any + +from nextmcp import NextMCP +from nextmcp.auth import ( + APIKeyProvider, + AuthContext, + Permission, + PermissionManifest, + Role, + requires_auth_async, + requires_manifest_async, +) + +# Initialize MCP server +mcp = NextMCP("Permission Manifest Example") + +# Create permission manifest +manifest = PermissionManifest() + +# Option 1: Define manifest programmatically +print("Defining manifest programmatically...") + +manifest.define_scope( + name="read:data", + description="Read access to data", + oauth_mapping={ + "github": ["repo:read"], + "google": ["drive.readonly"], + }, +) + +manifest.define_scope( + name="write:data", + description="Write access to data", + oauth_mapping={ + "github": ["repo:write"], + "google": ["drive.file"], + }, +) + +manifest.define_tool_permission( + tool_name="query_database", + roles=["viewer", "editor", "admin"], + permissions=["read:data"], + description="Query database for information", + dangerous=False, +) + +manifest.define_tool_permission( + tool_name="update_database", + roles=["editor", "admin"], + permissions=["write:data"], + description="Update database records", + dangerous=False, +) + +manifest.define_tool_permission( + tool_name="delete_all_data", + roles=["admin"], + permissions=["admin:all"], + scopes=["admin:full"], + description="Delete all data (DANGEROUS)", + dangerous=True, +) + +manifest.define_tool_permission( + tool_name="export_user_data", + roles=["admin", "data_analyst"], + permissions=["export:data", "read:data"], + description="Export user data for analysis", + dangerous=False, +) + +# Option 2: Load from YAML (demonstrate round-trip) +print("Exporting manifest to YAML...\n") + +yaml_content = """ +scopes: + - name: "read:data" + description: "Read access to data" + oauth_mapping: + github: + - "repo:read" + google: + - "drive.readonly" + + - name: "write:data" + description: "Write access to data" + oauth_mapping: + github: + - "repo:write" + google: + - "drive.file" + +tools: + query_database: + roles: + - "viewer" + - "editor" + - "admin" + permissions: + - "read:data" + description: "Query database for information" + dangerous: false + + update_database: + roles: + - "editor" + - "admin" + permissions: + - "write:data" + description: "Update database records" + dangerous: false + + delete_all_data: + roles: + - "admin" + permissions: + - "admin:all" + scopes: + - "admin:full" + description: "Delete all data (DANGEROUS)" + dangerous: true + + export_user_data: + roles: + - "admin" + - "data_analyst" + permissions: + - "export:data" + - "read:data" + description: "Export user data for analysis" + dangerous: false +""" + +print("Example manifest.yaml:") +print("=" * 60) +print(yaml_content) +print("=" * 60) +print() + +# Create auth provider with different user roles +print("Creating auth provider with test users...\n") + +auth_provider = APIKeyProvider( + valid_keys={ + "viewer_key": { + "user_id": "viewer_user", + "username": "Viewer User", + "roles": ["viewer"], + "permissions": ["read:data"], + }, + "editor_key": { + "user_id": "editor_user", + "username": "Editor User", + "roles": ["editor"], + "permissions": ["read:data", "write:data"], + }, + "admin_key": { + "user_id": "admin_user", + "username": "Admin User", + "roles": ["admin"], + "permissions": ["admin:all", "read:data", "write:data"], + "scopes": ["admin:full"], + }, + "analyst_key": { + "user_id": "analyst_user", + "username": "Data Analyst", + "roles": ["data_analyst"], + "permissions": ["read:data", "export:data"], + }, + } +) + + +# Define tools with manifest enforcement + +@mcp.tool() +@requires_auth_async(provider=auth_provider) +@requires_manifest_async(manifest=manifest) +async def query_database(auth: AuthContext, query: str) -> dict[str, Any]: + """ + Query the database. + + Requires: viewer, editor, or admin role + read:data permission + (as defined in manifest) + + Args: + auth: Authentication context (injected) + query: SQL-like query string + + Returns: + Query results + + Example: + # As viewer + result = await query_database( + auth={"api_key": "viewer_key"}, + query="SELECT * FROM users LIMIT 10" + ) + """ + return { + "success": True, + "user": auth.username, + "query": query, + "results": [ + {"id": 1, "name": "Alice", "role": "admin"}, + {"id": 2, "name": "Bob", "role": "editor"}, + {"id": 3, "name": "Charlie", "role": "viewer"}, + ], + } + + +@mcp.tool() +@requires_auth_async(provider=auth_provider) +@requires_manifest_async(manifest=manifest) +async def update_database( + auth: AuthContext, + record_id: int, + data: dict[str, Any], +) -> dict[str, Any]: + """ + Update a database record. + + Requires: editor or admin role + write:data permission + (as defined in manifest) + + Args: + auth: Authentication context (injected) + record_id: ID of record to update + data: New data for record + + Returns: + Update result + + Example: + # As editor + result = await update_database( + auth={"api_key": "editor_key"}, + record_id=2, + data={"name": "Bob Updated"} + ) + """ + return { + "success": True, + "user": auth.username, + "updated_record": record_id, + "data": data, + "message": "Record updated successfully", + } + + +@mcp.tool() +@requires_auth_async(provider=auth_provider) +@requires_manifest_async(manifest=manifest) +async def delete_all_data(auth: AuthContext, confirmation: str) -> dict[str, Any]: + """ + Delete ALL data (DANGEROUS operation). + + Requires: admin role + admin:all permission + admin:full scope + (as defined in manifest - marked as dangerous) + + Args: + auth: Authentication context (injected) + confirmation: Must be "I UNDERSTAND THIS IS PERMANENT" + + Returns: + Deletion result + + Example: + # As admin + result = await delete_all_data( + auth={"api_key": "admin_key"}, + confirmation="I UNDERSTAND THIS IS PERMANENT" + ) + """ + if confirmation != "I UNDERSTAND THIS IS PERMANENT": + return { + "success": False, + "error": "Invalid confirmation. This is a dangerous operation.", + } + + return { + "success": True, + "user": auth.username, + "message": "All data deleted (simulated - would require additional confirmation in production)", + "warning": "This operation is marked as DANGEROUS in the manifest", + } + + +@mcp.tool() +@requires_auth_async(provider=auth_provider) +@requires_manifest_async(manifest=manifest) +async def export_user_data( + auth: AuthContext, + format: str = "json", +) -> dict[str, Any]: + """ + Export user data for analysis. + + Requires: admin OR data_analyst role + (export:data OR read:data) permission + (as defined in manifest - uses OR logic for multiple requirements) + + Args: + auth: Authentication context (injected) + format: Export format (json, csv, xlsx) + + Returns: + Exported data + + Example: + # As data analyst + result = await export_user_data( + auth={"api_key": "analyst_key"}, + format="csv" + ) + """ + return { + "success": True, + "user": auth.username, + "format": format, + "data_url": f"/exports/users_{auth.user_id}.{format}", + "record_count": 1000, + "message": "Data export prepared successfully", + } + + +# Demonstration function +async def demonstrate_manifest(): + """Demonstrate manifest enforcement with different user roles.""" + print("\n" + "=" * 60) + print("MANIFEST ENFORCEMENT DEMONSTRATION") + print("=" * 60 + "\n") + + # Test 1: Viewer can query + print("Test 1: Viewer querying database (should succeed)") + try: + result = await query_database( + auth={"api_key": "viewer_key"}, + query="SELECT * FROM users", + ) + print(f"✓ Success: {result['user']} queried database\n") + except Exception as e: + print(f"✗ Failed: {e}\n") + + # Test 2: Viewer cannot update + print("Test 2: Viewer updating database (should fail)") + try: + result = await update_database( + auth={"api_key": "viewer_key"}, + record_id=1, + data={"name": "Updated"}, + ) + print(f"✓ Unexpected success\n") + except Exception as e: + print(f"✓ Correctly denied: {e}\n") + + # Test 3: Editor can update + print("Test 3: Editor updating database (should succeed)") + try: + result = await update_database( + auth={"api_key": "editor_key"}, + record_id=1, + data={"name": "Updated"}, + ) + print(f"✓ Success: {result['user']} updated record\n") + except Exception as e: + print(f"✗ Failed: {e}\n") + + # Test 4: Editor cannot delete + print("Test 4: Editor deleting all data (should fail)") + try: + result = await delete_all_data( + auth={"api_key": "editor_key"}, + confirmation="I UNDERSTAND THIS IS PERMANENT", + ) + print(f"✓ Unexpected success\n") + except Exception as e: + print(f"✓ Correctly denied: {e}\n") + + # Test 5: Admin can delete + print("Test 5: Admin deleting all data (should succeed)") + try: + result = await delete_all_data( + auth={"api_key": "admin_key"}, + confirmation="I UNDERSTAND THIS IS PERMANENT", + ) + print(f"✓ Success: {result['user']} performed dangerous operation\n") + except Exception as e: + print(f"✗ Failed: {e}\n") + + # Test 6: Data analyst can export + print("Test 6: Data analyst exporting data (should succeed)") + try: + result = await export_user_data( + auth={"api_key": "analyst_key"}, + format="csv", + ) + print(f"✓ Success: {result['user']} exported data\n") + except Exception as e: + print(f"✗ Failed: {e}\n") + + print("=" * 60) + print("MANIFEST SUMMARY") + print("=" * 60) + print(f"Total scopes defined: {len(manifest.scopes)}") + print(f"Total tools protected: {len(manifest.tools)}") + print(f"Dangerous tools: {sum(1 for t in manifest.tools.values() if t.dangerous)}") + print("\nProtected tools:") + for tool_name, tool_perm in manifest.tools.items(): + danger_flag = " [DANGEROUS]" if tool_perm.dangerous else "" + print(f" - {tool_name}{danger_flag}") + if tool_perm.roles: + print(f" Roles: {', '.join(tool_perm.roles)}") + if tool_perm.permissions: + print(f" Permissions: {', '.join(tool_perm.permissions)}") + if tool_perm.scopes: + print(f" Scopes: {', '.join(tool_perm.scopes)}") + print() + + +if __name__ == "__main__": + print("Starting Permission Manifest MCP Server...\n") + + # Run demonstration + asyncio.run(demonstrate_manifest()) + + print("\n" + "=" * 60) + print("KEY CONCEPTS") + print("=" * 60) + print("1. Declarative Security: Define requirements in YAML/code") + print("2. Automatic Enforcement: @requires_manifest_async decorator") + print("3. Flexible Requirements: Roles, permissions, and scopes") + print("4. OR Logic: User needs ANY ONE from each type") + print("5. AND Logic: User needs ALL types that are specified") + print("6. Dangerous Flags: Mark operations requiring extra care") + print("=" * 60) diff --git a/examples/auth/multi_provider_server.py b/examples/auth/multi_provider_server.py new file mode 100644 index 0000000..da32a7a --- /dev/null +++ b/examples/auth/multi_provider_server.py @@ -0,0 +1,425 @@ +""" +Multi-Provider OAuth Server Example + +This example demonstrates an MCP server that supports multiple OAuth providers: +- Google OAuth (for email and Drive access) +- GitHub OAuth (for repository access) + +Users can authenticate with either provider, and the server maintains separate +sessions for each provider. + +Run this server: + python examples/auth/multi_provider_server.py + +Features: +- Multiple OAuth provider support +- Provider-specific tools +- Cross-provider user identification +- Session management per provider +""" + +import os +from pathlib import Path + +from fastmcp import FastMCP + +from nextmcp.auth import ( + GitHubOAuthProvider, + GoogleOAuthProvider, + requires_scope_async, +) +from nextmcp.auth.core import AuthContext +from nextmcp.protocol import ( + AuthFlowType, + AuthMetadata, + AuthRequirement, +) +from nextmcp.session import FileSessionStore + +# ============================================================================ +# Configuration +# ============================================================================ + +# Get OAuth credentials from environment +GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID", "your-google-client-id") +GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET", "your-google-client-secret") + +GITHUB_CLIENT_ID = os.getenv("GITHUB_CLIENT_ID", "your-github-client-id") +GITHUB_CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET", "your-github-client-secret") + +# Session storage +SESSION_DIR = Path(".multi_provider_sessions") + + +# ============================================================================ +# Initialize Server +# ============================================================================ + +mcp = FastMCP("Multi-Provider OAuth Server") + + +# ============================================================================ +# Set Up OAuth Providers +# ============================================================================ + +google = GoogleOAuthProvider( + client_id=GOOGLE_CLIENT_ID, + client_secret=GOOGLE_CLIENT_SECRET, + redirect_uri="http://localhost:8080/oauth/callback", + scope=["openid", "email", "profile", "https://www.googleapis.com/auth/drive.readonly"], +) + +github = GitHubOAuthProvider( + client_id=GITHUB_CLIENT_ID, + client_secret=GITHUB_CLIENT_SECRET, + redirect_uri="http://localhost:8080/oauth/callback", + scope=["read:user", "repo"], +) + + +# ============================================================================ +# Set Up Session Store +# ============================================================================ + +session_store = FileSessionStore(SESSION_DIR) + + +# ============================================================================ +# Create Auth Metadata +# ============================================================================ + +auth_metadata = AuthMetadata( + requirement=AuthRequirement.REQUIRED, + supports_multi_user=True, + token_refresh_enabled=True, +) + +# Add Google provider +auth_metadata.add_provider( + name="google", + type="oauth2", + flows=[AuthFlowType.OAUTH2_PKCE], + authorization_url="https://accounts.google.com/o/oauth2/v2/auth", + token_url="https://oauth2.googleapis.com/token", + scopes=["openid", "email", "profile", "https://www.googleapis.com/auth/drive.readonly"], + supports_refresh=True, + supports_pkce=True, +) + +# Add GitHub provider +auth_metadata.add_provider( + name="github", + type="oauth2", + flows=[AuthFlowType.OAUTH2_PKCE], + authorization_url="https://github.com/login/oauth/authorize", + token_url="https://github.com/login/oauth/access_token", + scopes=["read:user", "repo"], + supports_refresh=False, # GitHub doesn't support refresh tokens + supports_pkce=True, +) + + +# ============================================================================ +# Public Tools +# ============================================================================ + + +@mcp.tool() +def get_auth_metadata() -> dict: + """ + Get server authentication requirements. + + Returns information about supported OAuth providers and required scopes. + + Returns: + dict: Complete authentication metadata including both Google and GitHub + """ + return auth_metadata.to_dict() + + +@mcp.tool() +def get_supported_providers() -> dict: + """ + List supported OAuth providers. + + Returns: + dict: Information about Google and GitHub OAuth providers + """ + return { + "providers": [ + { + "name": "google", + "display_name": "Google", + "features": ["Email", "Drive access", "Profile"], + "scopes": ["openid", "email", "profile", "drive.readonly"], + "supports_refresh": True, + }, + { + "name": "github", + "display_name": "GitHub", + "features": ["Profile", "Repository access"], + "scopes": ["read:user", "repo"], + "supports_refresh": False, + }, + ], + "notes": [ + "You can authenticate with either provider", + "Some tools require specific providers", + "Google supports token refresh, GitHub doesn't", + ], + } + + +# ============================================================================ +# Google-Specific Tools +# ============================================================================ + + +@mcp.tool() +async def get_google_profile() -> dict: + """ + Get user's Google profile. + + Requires Google OAuth authentication. + + Returns: + dict: User's email, name, and Google-specific info + """ + # In a real implementation, this would check the auth context + # to ensure user authenticated with Google + return { + "message": "This would return Google profile information", + "required_provider": "google", + "required_scopes": ["openid", "email", "profile"], + } + + +@mcp.tool() +async def list_google_drive_files() -> dict: + """ + List files from user's Google Drive. + + Requires Google OAuth with drive.readonly scope. + + Returns: + dict: List of files from Google Drive + """ + return { + "message": "This would list files from Google Drive", + "required_provider": "google", + "required_scopes": ["https://www.googleapis.com/auth/drive.readonly"], + } + + +# ============================================================================ +# GitHub-Specific Tools +# ============================================================================ + + +@mcp.tool() +async def get_github_profile() -> dict: + """ + Get user's GitHub profile. + + Requires GitHub OAuth authentication. + + Returns: + dict: User's GitHub username, repos, and profile info + """ + return { + "message": "This would return GitHub profile information", + "required_provider": "github", + "required_scopes": ["read:user"], + } + + +@mcp.tool() +async def list_github_repos(visibility: str = "all") -> dict: + """ + List user's GitHub repositories. + + Requires GitHub OAuth with repo scope. + + Args: + visibility: "all", "public", or "private" + + Returns: + dict: List of repositories + """ + return { + "message": f"This would list {visibility} repositories", + "required_provider": "github", + "required_scopes": ["repo"], + } + + +@mcp.tool() +async def create_github_issue(repo: str, title: str, body: str) -> dict: + """ + Create an issue on a GitHub repository. + + Requires GitHub OAuth with repo scope. + + Args: + repo: Repository name (owner/repo) + title: Issue title + body: Issue description + + Returns: + dict: Created issue information + """ + return { + "message": f"This would create issue '{title}' on {repo}", + "required_provider": "github", + "required_scopes": ["repo"], + } + + +# ============================================================================ +# Cross-Provider Tools +# ============================================================================ + + +@mcp.tool() +async def get_unified_profile() -> dict: + """ + Get unified user profile across providers. + + Works with either Google or GitHub authentication. + Returns provider-specific information based on which provider was used. + + Returns: + dict: Unified profile information + """ + # In a real implementation, this would: + # 1. Check which provider user authenticated with + # 2. Fetch profile from that provider + # 3. Return normalized profile data + return { + "message": "This would return profile from whichever provider user used", + "supports": ["google", "github"], + } + + +@mcp.tool() +async def link_provider_accounts() -> dict: + """ + Link Google and GitHub accounts for the same user. + + Allows a user to authenticate with both providers and link their accounts + for a unified experience. + + Returns: + dict: Account linking status + """ + return { + "message": "This would link Google and GitHub accounts", + "note": "User would need to authenticate with both providers", + "benefits": [ + "Access Drive files AND GitHub repos", + "Unified identity across providers", + "Seamless cross-service operations", + ], + } + + +# ============================================================================ +# Session Management +# ============================================================================ + + +@mcp.tool() +async def get_active_sessions_by_provider() -> dict: + """ + Get active sessions grouped by OAuth provider. + + Returns: + dict: Sessions grouped by provider (Google vs GitHub) + """ + all_sessions = session_store.list_users() + + # In a real implementation, you would: + # 1. Load each session + # 2. Check session.provider + # 3. Group by provider + + return { + "total_sessions": len(all_sessions), + "message": "This would group sessions by provider", + "note": "Session data includes 'provider' field", + } + + +# ============================================================================ +# Utility Functions +# ============================================================================ + + +def print_startup_message(): + """Print helpful startup information.""" + print("\n" + "=" * 70) + print("🚀 Multi-Provider OAuth Server Started") + print("=" * 70) + print("\nSupported OAuth Providers:") + print(" 🔵 Google OAuth") + print(" - Scopes: openid, email, profile, drive.readonly") + print(" - Refresh tokens: Yes") + print(" - Use for: Email, Drive, Profile") + print("\n ⚫ GitHub OAuth") + print(" - Scopes: read:user, repo") + print(" - Refresh tokens: No") + print(" - Use for: Repositories, Issues, Profile") + print("\nAvailable Tools:") + print(" 📋 Public:") + print(" - get_auth_metadata") + print(" - get_supported_providers") + print("\n 🔵 Google-specific:") + print(" - get_google_profile") + print(" - list_google_drive_files") + print("\n ⚫ GitHub-specific:") + print(" - get_github_profile") + print(" - list_github_repos") + print(" - create_github_issue") + print("\n 🔗 Cross-provider:") + print(" - get_unified_profile") + print(" - link_provider_accounts") + print(" - get_active_sessions_by_provider") + print("\nTo get OAuth tokens:") + print(" Google:") + print(" python examples/auth/oauth_token_helper.py --provider google") + print("\n GitHub:") + print(" python examples/auth/oauth_token_helper.py --provider github") + print("\n" + "=" * 70 + "\n") + + +# ============================================================================ +# Main Entry Point +# ============================================================================ + + +def main(): + """Run the server.""" + # Print startup information + print_startup_message() + + # Ensure session directory exists + SESSION_DIR.mkdir(parents=True, exist_ok=True) + + # Note: In a real implementation, you would: + # 1. Create separate middleware for each provider + # 2. Route requests to appropriate provider based on token + # 3. Or use a single middleware with provider detection + + # For now, this is a demonstration of the concept + print("Note: This is a demonstration server.") + print("For actual multi-provider support, you would need:") + print(" - Provider detection from access token") + print(" - Separate middleware per provider") + print(" - Token routing logic") + + # Run server + mcp.run() + + +if __name__ == "__main__": + main() diff --git a/examples/auth/oauth_token_helper.py b/examples/auth/oauth_token_helper.py new file mode 100644 index 0000000..980d5f0 --- /dev/null +++ b/examples/auth/oauth_token_helper.py @@ -0,0 +1,430 @@ +""" +OAuth Token Helper Script + +This script helps you obtain OAuth access tokens for testing integration tests. +It provides an interactive workflow to: +1. Generate authorization URLs +2. Handle OAuth callbacks +3. Exchange codes for tokens +4. Display environment variables to set + +Usage: + # Interactive mode - prompts for provider + python examples/auth/oauth_token_helper.py + + # Specify provider + python examples/auth/oauth_token_helper.py --provider github + python examples/auth/oauth_token_helper.py --provider google + + # Manual mode (no callback server) + python examples/auth/oauth_token_helper.py --provider github --manual +""" + +import argparse +import asyncio +import os +import sys +import webbrowser +from urllib.parse import parse_qs, urlparse + +from nextmcp.auth import GitHubOAuthProvider, GoogleOAuthProvider + + +def print_header(text): + """Print a formatted header.""" + print("\n" + "=" * 70) + print(text) + print("=" * 70) + + +def print_step(number, text): + """Print a step number and description.""" + print(f"\n📍 Step {number}: {text}") + print("-" * 70) + + +def print_success(text): + """Print a success message.""" + print(f"✓ {text}") + + +def print_error(text): + """Print an error message.""" + print(f"✗ ERROR: {text}") + + +def print_warning(text): + """Print a warning message.""" + print(f"⚠️ WARNING: {text}") + + +def print_info(text): + """Print an info message.""" + print(f"ℹ️ {text}") + + +async def run_callback_server(state, verifier): + """ + Run a simple HTTP server to handle OAuth callback. + + Returns the authorization code or None if failed. + """ + from aiohttp import web + + code_container = {"code": None, "error": None} + + async def oauth_callback(request): + """Handle OAuth callback.""" + # Get code from query parameters + code = request.query.get("code") + error = request.query.get("error") + callback_state = request.query.get("state") + + if error: + code_container["error"] = error + return web.Response( + text=f"❌ Authorization failed: {error}\n\nYou can close this window.", + content_type="text/plain", + ) + + if not code: + code_container["error"] = "No authorization code received" + return web.Response( + text="❌ No authorization code received\n\nYou can close this window.", + content_type="text/plain", + ) + + if callback_state != state: + code_container["error"] = "State mismatch - possible CSRF attack" + return web.Response( + text="❌ Security error: State mismatch\n\nYou can close this window.", + content_type="text/plain", + ) + + code_container["code"] = code + + return web.Response( + text="✅ Authorization successful!\n\nYou can close this window and return to the terminal.", + content_type="text/plain", + ) + + # Create and start server + app = web.Application() + app.router.add_get("/oauth/callback", oauth_callback) + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "localhost", 8080) + + print_info("Starting local callback server on http://localhost:8080") + print_info("Waiting for authorization...") + + await site.start() + + # Wait for callback (with timeout) + timeout = 300 # 5 minutes + for _ in range(timeout): + if code_container["code"] or code_container["error"]: + break + await asyncio.sleep(1) + + # Cleanup + await runner.cleanup() + + if code_container["error"]: + print_error(code_container["error"]) + return None + + return code_container["code"] + + +async def get_github_token(client_id, client_secret, manual_mode=False): + """ + Interactive workflow to get GitHub access token. + + Args: + client_id: GitHub OAuth app client ID + client_secret: GitHub OAuth app client secret + manual_mode: If True, don't start callback server + + Returns: + Access token or None + """ + print_header("GITHUB OAUTH TOKEN HELPER") + + provider = GitHubOAuthProvider( + client_id=client_id, + client_secret=client_secret, + redirect_uri="http://localhost:8080/oauth/callback", + scope=["read:user", "repo"], + ) + + print_step(1, "Generating authorization URL") + auth_data = provider.generate_authorization_url() + url = auth_data["url"] + state = auth_data["state"] + verifier = auth_data["verifier"] + + print_success("Authorization URL generated") + print(f"\n📋 Authorization URL:") + print(f" {url}\n") + + if manual_mode: + print_step(2, "Manual authorization") + print("Please visit the URL above and authorize the application.") + print("\nAfter authorization, you'll be redirected to:") + print(" http://localhost:8080/oauth/callback?code=CODE&state=STATE") + print("\nCopy the 'code' parameter from the URL and paste it below:") + + code = input("\n🔑 Enter authorization code: ").strip() + if not code: + print_error("No code provided") + return None + else: + print_step(2, "Opening browser for authorization") + print("Your browser will open automatically...") + print("Please authorize the application in your browser.") + + # Open browser + webbrowser.open(url) + + # Start callback server + code = await run_callback_server(state, verifier) + if not code: + print_error("Failed to get authorization code") + return None + + print_success(f"Authorization code received: {code[:20]}...") + + print_step(3, "Exchanging code for access token") + try: + token_data = await provider.exchange_code_for_token( + code=code, + state=state, + verifier=verifier, + ) + + access_token = token_data["access_token"] + print_success("Access token obtained!") + + print_step(4, "Testing token with GitHub API") + user_info = await provider.get_user_info(access_token) + print_success(f"Token works! Authenticated as: {user_info.get('login')}") + + # Display results + print_header("GITHUB TOKEN OBTAINED SUCCESSFULLY") + print(f"\n✅ Access Token: {access_token}") + print(f"✅ Token Type: {token_data.get('token_type', 'bearer')}") + print(f"✅ Scope: {token_data.get('scope', 'N/A')}") + + print("\n📋 Set these environment variables:") + print(f" export GITHUB_CLIENT_ID=\"{client_id}\"") + print(f" export GITHUB_CLIENT_SECRET=\"{client_secret}\"") + print(f" export GITHUB_ACCESS_TOKEN=\"{access_token}\"") + + print("\n💾 Or add to .env file:") + print(f" GITHUB_CLIENT_ID={client_id}") + print(f" GITHUB_CLIENT_SECRET={client_secret}") + print(f" GITHUB_ACCESS_TOKEN={access_token}") + + return access_token + + except Exception as e: + print_error(f"Token exchange failed: {e}") + return None + + +async def get_google_token(client_id, client_secret, manual_mode=False): + """ + Interactive workflow to get Google access token and refresh token. + + Args: + client_id: Google OAuth client ID + client_secret: Google OAuth client secret + manual_mode: If True, don't start callback server + + Returns: + Tuple of (access_token, refresh_token) or (None, None) + """ + print_header("GOOGLE OAUTH TOKEN HELPER") + + provider = GoogleOAuthProvider( + client_id=client_id, + client_secret=client_secret, + redirect_uri="http://localhost:8080/oauth/callback", + scope=[ + "https://www.googleapis.com/auth/userinfo.profile", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/drive.readonly", + "https://www.googleapis.com/auth/gmail.readonly", + ], + ) + + print_step(1, "Generating authorization URL with offline access") + auth_data = provider.generate_authorization_url() + url = auth_data["url"] + state = auth_data["state"] + verifier = auth_data["verifier"] + + print_success("Authorization URL generated") + print(f"\n📋 Authorization URL:") + print(f" {url}\n") + print_info("Note: This includes 'access_type=offline' for refresh tokens") + + if manual_mode: + print_step(2, "Manual authorization") + print("Please visit the URL above and authorize the application.") + print("\nAfter authorization, you'll be redirected to:") + print(" http://localhost:8080/oauth/callback?code=CODE&state=STATE") + print("\nCopy the 'code' parameter from the URL and paste it below:") + + code = input("\n🔑 Enter authorization code: ").strip() + if not code: + print_error("No code provided") + return None, None + else: + print_step(2, "Opening browser for authorization") + print("Your browser will open automatically...") + print("Please sign in and authorize the application.") + + # Open browser + webbrowser.open(url) + + # Start callback server + code = await run_callback_server(state, verifier) + if not code: + print_error("Failed to get authorization code") + return None, None + + print_success(f"Authorization code received: {code[:20]}...") + + print_step(3, "Exchanging code for access and refresh tokens") + try: + token_data = await provider.exchange_code_for_token( + code=code, + state=state, + verifier=verifier, + ) + + access_token = token_data["access_token"] + refresh_token = token_data.get("refresh_token") + + print_success("Tokens obtained!") + + if not refresh_token: + print_warning("No refresh token received - you may need to revoke access and try again") + print_info("Refresh tokens are only issued on first authorization or with prompt=consent") + + print_step(4, "Testing token with Google API") + user_info = await provider.get_user_info(access_token) + print_success(f"Token works! Authenticated as: {user_info.get('email', 'Unknown')}") + + # Display results + print_header("GOOGLE TOKENS OBTAINED SUCCESSFULLY") + print(f"\n✅ Access Token: {access_token[:50]}...") + print(f"✅ Token Type: {token_data.get('token_type', 'Bearer')}") + print(f"✅ Expires In: {token_data.get('expires_in', 'Unknown')} seconds") + print(f"✅ Scope: {token_data.get('scope', 'N/A')}") + + if refresh_token: + print(f"✅ Refresh Token: {refresh_token[:50]}...") + else: + print(f"⚠️ Refresh Token: Not issued") + + print("\n📋 Set these environment variables:") + print(f" export GOOGLE_CLIENT_ID=\"{client_id}\"") + print(f" export GOOGLE_CLIENT_SECRET=\"{client_secret}\"") + print(f" export GOOGLE_ACCESS_TOKEN=\"{access_token}\"") + if refresh_token: + print(f" export GOOGLE_REFRESH_TOKEN=\"{refresh_token}\"") + + print("\n💾 Or add to .env file:") + print(f" GOOGLE_CLIENT_ID={client_id}") + print(f" GOOGLE_CLIENT_SECRET={client_secret}") + print(f" GOOGLE_ACCESS_TOKEN={access_token}") + if refresh_token: + print(f" GOOGLE_REFRESH_TOKEN={refresh_token}") + + return access_token, refresh_token + + except Exception as e: + print_error(f"Token exchange failed: {e}") + import traceback + traceback.print_exc() + return None, None + + +def main(): + """Main entry point.""" + parser = argparse.ArgumentParser( + description="OAuth Token Helper - Obtain tokens for integration testing" + ) + parser.add_argument( + "--provider", + choices=["github", "google"], + help="OAuth provider (github or google)", + ) + parser.add_argument( + "--manual", + action="store_true", + help="Manual mode (no callback server, paste code manually)", + ) + + args = parser.parse_args() + + # Welcome message + print_header("OAUTH TOKEN HELPER FOR INTEGRATION TESTING") + print("\nThis script helps you obtain OAuth access tokens for testing.") + print("You'll need OAuth app credentials from GitHub or Google.") + print("\nFor setup instructions, see: docs/OAUTH_TESTING_SETUP.md") + + # Determine provider + provider = args.provider + if not provider: + print("\n🔧 Select OAuth Provider:") + print(" 1. GitHub") + print(" 2. Google") + choice = input("\nEnter choice (1 or 2): ").strip() + + if choice == "1": + provider = "github" + elif choice == "2": + provider = "google" + else: + print_error("Invalid choice") + sys.exit(1) + + # Get credentials + if provider == "github": + client_id = os.getenv("GITHUB_CLIENT_ID") or input("\n🔑 GitHub Client ID: ").strip() + client_secret = os.getenv("GITHUB_CLIENT_SECRET") or input("🔑 GitHub Client Secret: ").strip() + + if not client_id or not client_secret: + print_error("Client ID and Secret are required") + print_info("Get them from: https://github.com/settings/developers") + sys.exit(1) + + asyncio.run(get_github_token(client_id, client_secret, args.manual)) + + elif provider == "google": + client_id = os.getenv("GOOGLE_CLIENT_ID") or input("\n🔑 Google Client ID: ").strip() + client_secret = os.getenv("GOOGLE_CLIENT_SECRET") or input("🔑 Google Client Secret: ").strip() + + if not client_id or not client_secret: + print_error("Client ID and Secret are required") + print_info("Get them from: https://console.cloud.google.com") + sys.exit(1) + + asyncio.run(get_google_token(client_id, client_secret, args.manual)) + + print("\n" + "=" * 70) + print("Next Steps:") + print("=" * 70) + print("1. Copy the export commands above and run them in your terminal") + print("2. Or add them to your .env file and run: export $(cat .env | xargs)") + print("3. Run integration tests: pytest tests/test_oauth_integration.py -v -m integration") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/examples/auth/session_management_example.py b/examples/auth/session_management_example.py new file mode 100644 index 0000000..4496e45 --- /dev/null +++ b/examples/auth/session_management_example.py @@ -0,0 +1,347 @@ +""" +Session Management Example + +This example demonstrates advanced session management features: +- Manual session creation and management +- Session inspection and monitoring +- Token refresh workflows +- Session cleanup strategies +- Custom session metadata + +This is useful for understanding how sessions work internally and for +building custom session management logic. + +Run this example: + python examples/auth/session_management_example.py +""" + +import asyncio +import time +from pathlib import Path + +from nextmcp.session import ( + FileSessionStore, + MemorySessionStore, + SessionData, +) + + +async def main(): + """Run session management demonstrations.""" + print("\n" + "=" * 70) + print("NextMCP Session Management Examples") + print("=" * 70 + "\n") + + # ======================================================================== + # Example 1: Basic Session Operations + # ======================================================================== + + print("📋 Example 1: Basic Session Operations") + print("-" * 70) + + # Create session store (using memory for demo) + session_store = MemorySessionStore() + + # Create a session + session = SessionData( + user_id="user123", + access_token="ya29.a0ATi6K2example_access_token", + refresh_token="1//01example_refresh_token", + expires_at=time.time() + 3600, # Expires in 1 hour + scopes=["openid", "email", "profile"], + user_info={ + "email": "user@example.com", + "name": "John Doe", + "picture": "https://example.com/photo.jpg", + }, + provider="google", + ) + + # Save session + session_store.save(session) + print(f"✓ Created session for user: {session.user_id}") + print(f" Provider: {session.provider}") + print(f" Scopes: {', '.join(session.scopes)}") + print(f" Expires: in {int((session.expires_at - time.time()) / 60)} minutes") + + # Load session + loaded = session_store.load("user123") + print(f"\n✓ Loaded session for user: {loaded.user_id}") + print(f" Email: {loaded.user_info.get('email')}") + print(f" Name: {loaded.user_info.get('name')}") + + # ======================================================================== + # Example 2: Token Expiration Handling + # ======================================================================== + + print("\n\n📋 Example 2: Token Expiration Handling") + print("-" * 70) + + # Create session expiring soon + expiring_session = SessionData( + user_id="user456", + access_token="token_expires_soon", + expires_at=time.time() + 120, # Expires in 2 minutes + scopes=["profile"], + provider="google", + ) + + session_store.save(expiring_session) + + # Check expiration + print(f"Token expired? {expiring_session.is_expired()}") + print(f"Needs refresh (5 min buffer)? {expiring_session.needs_refresh()}") + print(f"Needs refresh (1 min buffer)? {expiring_session.needs_refresh(buffer_seconds=60)}") + + # Create already-expired session + expired_session = SessionData( + user_id="user789", + access_token="token_expired", + expires_at=time.time() - 10, # Expired 10 seconds ago + scopes=["profile"], + provider="google", + ) + + session_store.save(expired_session) + + print(f"\nExpired token check: {expired_session.is_expired()}") + + # Clean up expired sessions + cleaned = session_store.cleanup_expired() + print(f"\n✓ Cleaned up {cleaned} expired session(s)") + print(f" Remaining sessions: {len(session_store.list_users())}") + + # ======================================================================== + # Example 3: Updating Tokens (Refresh Flow) + # ======================================================================== + + print("\n\n📋 Example 3: Updating Tokens (Refresh Flow)") + print("-" * 70) + + # Simulate token refresh + print("Simulating token refresh for user123...") + + # Load existing session + session = session_store.load("user123") + old_token = session.access_token + + # Update with new tokens (simulating OAuth refresh) + session_store.update_tokens( + user_id="user123", + access_token="ya29.a0NEW_ACCESS_TOKEN_after_refresh", + refresh_token="1//01NEW_REFRESH_TOKEN", + expires_in=3600, # New expiration (1 hour from now) + ) + + # Verify update + updated = session_store.load("user123") + print(f"✓ Token refreshed") + print(f" Old token: {old_token[:20]}...") + print(f" New token: {updated.access_token[:20]}...") + print(f" Expires in: {int((updated.expires_at - time.time()) / 60)} minutes") + + # ======================================================================== + # Example 4: Custom Session Metadata + # ======================================================================== + + print("\n\n📋 Example 4: Custom Session Metadata") + print("-" * 70) + + # Create session with custom metadata + session_with_metadata = SessionData( + user_id="poweruser", + access_token="token_with_metadata", + scopes=["admin"], + provider="google", + metadata={ + "subscription": "premium", + "preferences": { + "theme": "dark", + "notifications": True, + }, + "usage_stats": { + "requests_today": 42, + "last_request": time.time(), + }, + "roles": ["admin", "developer"], + }, + ) + + session_store.save(session_with_metadata) + + # Retrieve and use metadata + loaded = session_store.load("poweruser") + print("✓ Session with custom metadata:") + print(f" Subscription: {loaded.metadata.get('subscription')}") + print(f" Theme: {loaded.metadata.get('preferences', {}).get('theme')}") + print(f" Roles: {loaded.metadata.get('roles')}") + print(f" Requests today: {loaded.metadata.get('usage_stats', {}).get('requests_today')}") + + # ======================================================================== + # Example 5: File-Based Session Persistence + # ======================================================================== + + print("\n\n📋 Example 5: File-Based Session Persistence") + print("-" * 70) + + # Create file-based session store + file_store = FileSessionStore(".example_sessions") + + # Create sessions + for i in range(3): + session = SessionData( + user_id=f"file_user_{i}", + access_token=f"token_{i}", + scopes=["profile"], + provider="google", + user_info={"email": f"user{i}@example.com"}, + ) + file_store.save(session) + + print(f"✓ Created {len(file_store.list_users())} file-based sessions") + print(f" Storage location: {file_store.directory.absolute()}") + + # List files + files = list(file_store.directory.glob("session_*.json")) + print(f" Files created: {len(files)}") + for file in files: + print(f" - {file.name}") + + # Test persistence across instances + file_store2 = FileSessionStore(".example_sessions") + users = file_store2.list_users() + print(f"\n✓ Loaded {len(users)} sessions from disk (different instance)") + + # Cleanup + file_store.clear_all() + print(f"✓ Cleaned up file-based sessions") + + # ======================================================================== + # Example 6: Multi-User Session Management + # ======================================================================== + + print("\n\n📋 Example 6: Multi-User Session Management") + print("-" * 70) + + # Create sessions for multiple users + users = [ + ("alice", "google", ["email", "profile"]), + ("bob", "github", ["read:user", "repo"]), + ("charlie", "google", ["email", "profile", "drive.readonly"]), + ("diana", "github", ["read:user"]), + ] + + store = MemorySessionStore() + + for user_id, provider, scopes in users: + session = SessionData( + user_id=user_id, + access_token=f"token_{user_id}", + scopes=scopes, + provider=provider, + user_info={"username": user_id}, + ) + store.save(session) + + print(f"✓ Created {len(store.list_users())} user sessions") + + # Group by provider + google_users = [] + github_users = [] + + for user_id in store.list_users(): + session = store.load(user_id) + if session.provider == "google": + google_users.append(user_id) + else: + github_users.append(user_id) + + print(f"\n Google users: {google_users}") + print(f" GitHub users: {github_users}") + + # ======================================================================== + # Example 7: Session Monitoring + # ======================================================================== + + print("\n\n📋 Example 7: Session Monitoring") + print("-" * 70) + + # Create mix of sessions with different states + monitoring_store = MemorySessionStore() + + # Active session + active = SessionData( + user_id="active_user", + access_token="token_active", + expires_at=time.time() + 3600, + scopes=["profile"], + provider="google", + ) + monitoring_store.save(active) + + # Expiring soon + expiring = SessionData( + user_id="expiring_user", + access_token="token_expiring", + expires_at=time.time() + 120, # 2 minutes + scopes=["profile"], + provider="google", + ) + monitoring_store.save(expiring) + + # Expired + expired = SessionData( + user_id="expired_user", + access_token="token_expired", + expires_at=time.time() - 10, + scopes=["profile"], + provider="google", + ) + monitoring_store.save(expired) + + # Monitor sessions + all_users = monitoring_store.list_users() + print(f"Total sessions: {len(all_users)}") + + active_count = 0 + expiring_count = 0 + expired_count = 0 + + for user_id in all_users: + session = monitoring_store.load(user_id) + if session.is_expired(): + expired_count += 1 + print(f" ❌ {user_id}: Expired") + elif session.needs_refresh(): + expiring_count += 1 + print(f" ⚠️ {user_id}: Expiring soon") + else: + active_count += 1 + print(f" ✓ {user_id}: Active") + + print(f"\nSummary:") + print(f" Active: {active_count}") + print(f" Expiring soon: {expiring_count}") + print(f" Expired: {expired_count}") + + # ======================================================================== + # Cleanup + # ======================================================================== + + print("\n\n📋 Cleanup") + print("-" * 70) + + # Clean up example session directory + if Path(".example_sessions").exists(): + import shutil + + shutil.rmtree(".example_sessions") + print("✓ Removed example session directory") + + print("\n" + "=" * 70) + print("Examples completed!") + print("=" * 70 + "\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/nextmcp/auth/__init__.py b/nextmcp/auth/__init__.py index f156106..5324177 100644 --- a/nextmcp/auth/__init__.py +++ b/nextmcp/auth/__init__.py @@ -3,6 +3,12 @@ This module provides a comprehensive auth system inspired by next-auth, adapted for the Model Context Protocol (MCP). + +Includes support for: +- API Key, JWT, and Session authentication +- OAuth 2.0 with PKCE (GitHub, Google, and custom providers) +- Role-Based Access Control (RBAC) +- Fine-grained permissions """ from nextmcp.auth.core import ( @@ -12,20 +18,36 @@ Permission, Role, ) +from nextmcp.auth.errors import ( + AuthenticationError, + AuthorizationError, + ManifestViolationError, + OAuthRequiredError, + ScopeInsufficientError, +) +from nextmcp.auth.manifest import PermissionManifest, ScopeDefinition, ToolPermission from nextmcp.auth.middleware import ( requires_auth, requires_auth_async, + requires_manifest_async, requires_permission, requires_permission_async, requires_role, requires_role_async, + requires_scope_async, ) +from nextmcp.auth.oauth import OAuthConfig, OAuthProvider, PKCEChallenge +from nextmcp.auth.oauth_providers import GitHubOAuthProvider, GoogleOAuthProvider from nextmcp.auth.providers import ( APIKeyProvider, JWTProvider, SessionProvider, ) from nextmcp.auth.rbac import RBAC, PermissionDeniedError +from nextmcp.auth.request_middleware import ( + AuthEnforcementMiddleware, + create_auth_middleware, +) __all__ = [ # Core @@ -34,18 +56,39 @@ "AuthResult", "Permission", "Role", - # Middleware + # Errors + "AuthenticationError", + "AuthorizationError", + "OAuthRequiredError", + "ScopeInsufficientError", + "ManifestViolationError", + "PermissionDeniedError", + # Manifest + "PermissionManifest", + "ScopeDefinition", + "ToolPermission", + # Middleware (decorators) "requires_auth", "requires_auth_async", + "requires_manifest_async", "requires_permission", "requires_permission_async", "requires_role", "requires_role_async", + "requires_scope_async", + # Request Middleware (runtime enforcement) + "AuthEnforcementMiddleware", + "create_auth_middleware", # Providers "APIKeyProvider", "JWTProvider", "SessionProvider", + # OAuth + "OAuthProvider", + "OAuthConfig", + "PKCEChallenge", + "GitHubOAuthProvider", + "GoogleOAuthProvider", # RBAC "RBAC", - "PermissionDeniedError", ] diff --git a/nextmcp/auth/core.py b/nextmcp/auth/core.py index 8c93c3d..fdc4383 100644 --- a/nextmcp/auth/core.py +++ b/nextmcp/auth/core.py @@ -98,7 +98,7 @@ class AuthContext: Represents the authentication context for a request. This contains information about the authenticated user, their credentials, - roles, and permissions. It's passed to tools that require authentication. + roles, permissions, and OAuth scopes. It's passed to tools that require authentication. """ authenticated: bool = False @@ -106,6 +106,7 @@ class AuthContext: username: str | None = None roles: set[Role] = field(default_factory=set) permissions: set[Permission] = field(default_factory=set) + scopes: set[str] = field(default_factory=set) # OAuth scopes metadata: dict[str, Any] = field(default_factory=dict) def has_role(self, role_name: str) -> bool: @@ -125,6 +126,18 @@ def has_permission(self, permission_name: str) -> bool: # Check role permissions return any(r.has_permission(permission_name) for r in self.roles) + def has_scope(self, scope_name: str) -> bool: + """ + Check if user has a specific OAuth scope. + + Args: + scope_name: Scope name to check + + Returns: + True if user has the scope, False otherwise + """ + return scope_name in self.scopes + def add_role(self, role: Role | str) -> None: """Add a role to this auth context.""" if isinstance(role, str): @@ -137,6 +150,15 @@ def add_permission(self, permission: Permission | str) -> None: permission = Permission(permission) self.permissions.add(permission) + def add_scope(self, scope: str) -> None: + """ + Add an OAuth scope to this auth context. + + Args: + scope: Scope string to add + """ + self.scopes.add(scope) + @dataclass class AuthResult: diff --git a/nextmcp/auth/errors.py b/nextmcp/auth/errors.py new file mode 100644 index 0000000..7081456 --- /dev/null +++ b/nextmcp/auth/errors.py @@ -0,0 +1,206 @@ +""" +Specialized authentication and authorization error types. + +This module provides clear, structured exceptions for different auth failure scenarios: +- AuthenticationError: General authentication failure +- AuthorizationError: General authorization failure +- OAuthRequiredError: OAuth authentication is needed +- ScopeInsufficientError: User lacks required OAuth scopes +- ManifestViolationError: Permission manifest access check failed +""" + +from typing import Any + +from nextmcp.auth.core import AuthContext + + +class AuthenticationError(Exception): + """ + Raised when authentication fails. + + This is a general authentication error for any auth failure. + + Attributes: + message: Human-readable error message + required_scopes: OAuth scopes required (optional) + providers: Available auth providers (optional) + """ + + def __init__( + self, + message: str, + required_scopes: list[str] | None = None, + providers: list[Any] | None = None, + ): + """ + Initialize AuthenticationError. + + Args: + message: Error message + required_scopes: Required OAuth scopes + providers: Available authentication providers + """ + super().__init__(message) + self.message = message + self.required_scopes = required_scopes or [] + self.providers = providers or [] + + +class AuthorizationError(Exception): + """ + Raised when authorization fails. + + This is a general authorization error for permission/access denials. + + Attributes: + message: Human-readable error message + required: What was required for access + user_id: User ID who was denied + """ + + def __init__( + self, + message: str, + required: Any | None = None, + user_id: str | None = None, + ): + """ + Initialize AuthorizationError. + + Args: + message: Error message + required: Required permissions/scopes/roles + user_id: User ID + """ + super().__init__(message) + self.message = message + self.required = required + self.user_id = user_id + + +class OAuthRequiredError(Exception): + """ + Raised when OAuth authentication is required but not provided. + + This error indicates that a tool or operation requires OAuth authentication. + It can include the authorization URL to help users complete the OAuth flow. + + Attributes: + message: Human-readable error message + provider: OAuth provider name (e.g., "github", "google") + scopes: Required OAuth scopes + authorization_url: URL to initiate OAuth flow + user_id: Current user ID (if partially authenticated) + """ + + def __init__( + self, + message: str, + provider: str | None = None, + scopes: list[str] | None = None, + authorization_url: str | None = None, + user_id: str | None = None, + ): + """ + Initialize OAuthRequiredError. + + Args: + message: Error message + provider: OAuth provider name + scopes: Required OAuth scopes + authorization_url: URL to start OAuth flow + user_id: Current user ID + """ + super().__init__(message) + self.message = message + self.provider = provider + self.scopes = scopes or [] + self.authorization_url = authorization_url + self.user_id = user_id + + +class ScopeInsufficientError(Exception): + """ + Raised when user lacks required OAuth scopes. + + This error indicates that the user is authenticated but doesn't have + sufficient OAuth scopes to perform the requested operation. + + Attributes: + message: Human-readable error message + required_scopes: List of scopes required (user needs ANY one) + current_scopes: List of scopes user currently has + user_id: User ID who lacks scopes + """ + + def __init__( + self, + message: str, + required_scopes: list[str] | None = None, + current_scopes: list[str] | None = None, + user_id: str | None = None, + ): + """ + Initialize ScopeInsufficientError. + + Args: + message: Error message + required_scopes: Scopes that are required + current_scopes: Scopes the user currently has + user_id: User ID + """ + super().__init__(message) + self.message = message + self.required_scopes = required_scopes or [] + self.current_scopes = current_scopes or [] + self.user_id = user_id + + +class ManifestViolationError(Exception): + """ + Raised when permission manifest access check fails. + + This error indicates that the user attempted to access a tool but failed + the manifest-based access control check. The error includes details about + what was required and what the user had. + + Attributes: + message: Human-readable error message + tool_name: Name of the tool that was denied + required_roles: Roles that are required (user needs ANY one) + required_permissions: Permissions required (user needs ANY one) + required_scopes: OAuth scopes required (user needs ANY one) + user_id: User ID who was denied access + auth_context: Full authentication context for debugging + """ + + def __init__( + self, + message: str, + tool_name: str | None = None, + required_roles: list[str] | None = None, + required_permissions: list[str] | None = None, + required_scopes: list[str] | None = None, + user_id: str | None = None, + auth_context: AuthContext | None = None, + ): + """ + Initialize ManifestViolationError. + + Args: + message: Error message + tool_name: Tool that was denied + required_roles: Roles required for access + required_permissions: Permissions required for access + required_scopes: OAuth scopes required for access + user_id: User ID + auth_context: Full auth context + """ + super().__init__(message) + self.message = message + self.tool_name = tool_name + self.required_roles = required_roles or [] + self.required_permissions = required_permissions or [] + self.required_scopes = required_scopes or [] + self.user_id = user_id + self.auth_context = auth_context diff --git a/nextmcp/auth/manifest.py b/nextmcp/auth/manifest.py new file mode 100644 index 0000000..562d621 --- /dev/null +++ b/nextmcp/auth/manifest.py @@ -0,0 +1,277 @@ +""" +Permission Manifest system for NextMCP. + +This module provides declarative security definitions using manifests, +allowing tools to specify their permission, role, and scope requirements +in a structured YAML/JSON format. +""" + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from nextmcp.auth.core import AuthContext + + +@dataclass +class ScopeDefinition: + """ + Defines an OAuth scope with metadata and provider mappings. + + Scopes can be mapped to provider-specific OAuth scopes for + multi-provider support (e.g., GitHub repo:read -> Google drive.readonly). + """ + + name: str + description: str + oauth_mapping: dict[str, list[str]] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + """Export scope definition to dictionary.""" + return { + "name": self.name, + "description": self.description, + "oauth_mapping": self.oauth_mapping, + } + + +@dataclass +class ToolPermission: + """ + Defines permission requirements for a tool. + + A tool can require: + - Permissions: Fine-grained permission strings (e.g., "read:data") + - Scopes: OAuth scopes (e.g., "repo:read") + - Roles: Role names (e.g., "admin") + + All requirement types use OR logic within their type, + but AND logic between types (must satisfy all types that are specified). + """ + + tool_name: str + permissions: list[str] = field(default_factory=list) + scopes: list[str] = field(default_factory=list) + roles: list[str] = field(default_factory=list) + description: str = "" + dangerous: bool = False + + def to_dict(self) -> dict[str, Any]: + """Export tool permission to dictionary.""" + return { + "permissions": self.permissions, + "scopes": self.scopes, + "roles": self.roles, + "description": self.description, + "dangerous": self.dangerous, + } + + +class PermissionManifest: + """ + Permission manifest for declarative security definitions. + + Manifests define: + 1. Scopes: OAuth scope definitions with provider mappings + 2. Tools: Tool permission requirements + + Can be loaded from YAML/JSON files or defined programmatically. + Used to enforce access control on MCP tools at runtime. + """ + + def __init__(self) -> None: + """Initialize an empty permission manifest.""" + self.scopes: dict[str, ScopeDefinition] = {} + self.tools: dict[str, ToolPermission] = {} + + def define_scope( + self, + name: str, + description: str, + oauth_mapping: dict[str, list[str]] | None = None, + ) -> ScopeDefinition: + """ + Define a scope in the manifest. + + Args: + name: Scope name (e.g., "read:data") + description: Human-readable description + oauth_mapping: Provider-specific OAuth scope mappings + + Returns: + The created ScopeDefinition + """ + scope = ScopeDefinition( + name=name, + description=description, + oauth_mapping=oauth_mapping or {}, + ) + self.scopes[name] = scope + return scope + + def define_tool_permission( + self, + tool_name: str, + permissions: list[str] | None = None, + scopes: list[str] | None = None, + roles: list[str] | None = None, + description: str = "", + dangerous: bool = False, + ) -> ToolPermission: + """ + Define permission requirements for a tool. + + Args: + tool_name: Name of the tool + permissions: Required permissions (user needs ANY one) + scopes: Required OAuth scopes (user needs ANY one) + roles: Required roles (user needs ANY one) + description: Human-readable description + dangerous: Whether this tool is dangerous (requires extra confirmation) + + Returns: + The created ToolPermission + """ + tool = ToolPermission( + tool_name=tool_name, + permissions=permissions or [], + scopes=scopes or [], + roles=roles or [], + description=description, + dangerous=dangerous, + ) + self.tools[tool_name] = tool + return tool + + def load_from_dict(self, data: dict[str, Any]) -> None: + """ + Load manifest from a dictionary. + + Expected format: + { + "scopes": [ + { + "name": "read:data", + "description": "Read data", + "oauth_mapping": {"github": ["repo:read"]} + } + ], + "tools": { + "query_db": { + "permissions": ["read:data"], + "scopes": ["db.query.read"], + "roles": ["viewer"], + "description": "Query database", + "dangerous": false + } + } + } + + Args: + data: Dictionary containing manifest data + """ + # Load scopes + for scope_data in data.get("scopes", []): + self.define_scope( + name=scope_data["name"], + description=scope_data.get("description", ""), + oauth_mapping=scope_data.get("oauth_mapping", {}), + ) + + # Load tools + for tool_name, tool_data in data.get("tools", {}).items(): + self.define_tool_permission( + tool_name=tool_name, + permissions=tool_data.get("permissions", []), + scopes=tool_data.get("scopes", []), + roles=tool_data.get("roles", []), + description=tool_data.get("description", ""), + dangerous=tool_data.get("dangerous", False), + ) + + def load_from_yaml(self, path: str) -> None: + """ + Load manifest from a YAML file. + + Args: + path: Path to YAML file + + Raises: + FileNotFoundError: If file doesn't exist + """ + import yaml + + yaml_path = Path(path) + if not yaml_path.exists(): + raise FileNotFoundError(f"Manifest file not found: {path}") + + with open(yaml_path) as f: + data = yaml.safe_load(f) + + self.load_from_dict(data or {}) + + def to_dict(self) -> dict[str, Any]: + """ + Export manifest to dictionary. + + Returns: + Dictionary containing all scopes and tools + """ + return { + "scopes": [scope.to_dict() for scope in self.scopes.values()], + "tools": {name: tool.to_dict() for name, tool in self.tools.items()}, + } + + def check_tool_access(self, tool_name: str, context: AuthContext) -> tuple[bool, str | None]: + """ + Check if an auth context has access to a tool. + + Logic: + - If tool not in manifest, allow access (unrestricted) + - If tool has no requirements, allow access + - If tool has requirements, must satisfy ALL requirement types (AND) + - Within each type (roles/permissions/scopes), need ANY one (OR) + + Args: + tool_name: Name of the tool to check + context: Authentication context to check + + Returns: + Tuple of (allowed: bool, error_message: str | None) + """ + # If tool not defined in manifest, allow access (no restrictions) + if tool_name not in self.tools: + return (True, None) + + tool = self.tools[tool_name] + + # If tool has no requirements, allow access + if not tool.roles and not tool.permissions and not tool.scopes: + return (True, None) + + # Check each requirement type (AND logic between types) + # Must satisfy all types that have requirements + + # Check roles (if specified) + if tool.roles: + has_required_role = any(context.has_role(role) for role in tool.roles) + if not has_required_role: + roles_str = ", ".join(tool.roles) + return (False, f"One of the following roles required: {roles_str}") + + # Check permissions (if specified) + if tool.permissions: + has_required_permission = any(context.has_permission(perm) for perm in tool.permissions) + if not has_required_permission: + perms_str = ", ".join(tool.permissions) + return (False, f"One of the following permissions required: {perms_str}") + + # Check scopes (if specified) + if tool.scopes: + has_required_scope = any(context.has_scope(scope) for scope in tool.scopes) + if not has_required_scope: + scopes_str = ", ".join(tool.scopes) + return (False, f"One of the following scopes required: {scopes_str}") + + # All requirements satisfied + return (True, None) diff --git a/nextmcp/auth/middleware.py b/nextmcp/auth/middleware.py index 3fd15b9..703f584 100644 --- a/nextmcp/auth/middleware.py +++ b/nextmcp/auth/middleware.py @@ -8,7 +8,10 @@ import functools import logging from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from nextmcp.auth.manifest import PermissionManifest from nextmcp.auth.core import AuthContext, AuthProvider from nextmcp.auth.rbac import PermissionDeniedError @@ -353,5 +356,138 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: return decorator +def requires_scope_async(*required_scopes: str) -> Callable: + """ + Async middleware decorator that requires specific OAuth scopes. + + Must be used with @requires_auth_async. + The auth context from the auth middleware is checked for required scopes. + + Args: + *required_scopes: Scope names required (user must have at least one) + + Example: + @app.tool() + @requires_auth_async(provider=github_oauth) + @requires_scope_async("repo:read", "repo:write") + async def access_repo(auth: AuthContext) -> dict: + return {"status": "authorized"} + """ + + def decorator(fn: Callable) -> Callable: + @functools.wraps(fn) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + # First argument should be AuthContext (from requires_auth_async) + if not args or not isinstance(args[0], AuthContext): + raise AuthenticationError( + "requires_scope_async must be used with requires_auth_async" + ) + + auth_context = args[0] + + # Check if user has any of the required scopes + has_scope = any(auth_context.has_scope(scope) for scope in required_scopes) + + if not has_scope: + scopes_str = ", ".join(required_scopes) + raise PermissionDeniedError( + f"One of the following scopes required: {scopes_str}", + required=scopes_str, + user_id=auth_context.user_id, + ) + + import asyncio + + if asyncio.iscoroutinefunction(fn): + return await fn(*args, **kwargs) + else: + return fn(*args, **kwargs) + + # Mark function as requiring scopes + wrapper._requires_scopes = required_scopes # type: ignore + + return wrapper + + return decorator + + +def requires_manifest_async( + manifest: "PermissionManifest | None" = None, + tool_name: str | None = None, +) -> Callable: + """ + Async middleware decorator that enforces PermissionManifest access control. + + Must be used with @requires_auth_async. + The auth context is checked against the manifest's tool requirements. + + Args: + manifest: PermissionManifest to enforce + tool_name: Name of tool to check (if None, uses function name) + + Example: + manifest = PermissionManifest() + manifest.define_tool_permission("admin_tool", roles=["admin"]) + + @app.tool() + @requires_auth_async(provider=api_key_provider) + @requires_manifest_async(manifest=manifest, tool_name="admin_tool") + async def admin_tool(auth: AuthContext) -> str: + return "Admin action performed" + """ + from nextmcp.auth.errors import ManifestViolationError + + def decorator(fn: Callable) -> Callable: + @functools.wraps(fn) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + # First argument should be AuthContext (from requires_auth_async) + if not args or not isinstance(args[0], AuthContext): + raise AuthenticationError( + "requires_manifest_async must be used with requires_auth_async" + ) + + auth_context = args[0] + + if manifest is None: + raise AuthenticationError("No manifest configured for requires_manifest_async") + + # Determine tool name (use parameter or function name) + actual_tool_name = tool_name if tool_name else fn.__name__ + + # Check manifest access + allowed, error_message = manifest.check_tool_access(actual_tool_name, auth_context) + + if not allowed: + # Get tool definition for error details + tool_def = manifest.tools.get(actual_tool_name) + + raise ManifestViolationError( + message=error_message or "Access denied by manifest", + tool_name=actual_tool_name, + required_roles=tool_def.roles if tool_def else [], + required_permissions=tool_def.permissions if tool_def else [], + required_scopes=tool_def.scopes if tool_def else [], + user_id=auth_context.user_id, + auth_context=auth_context, + ) + + # Access allowed - execute function + import asyncio + + if asyncio.iscoroutinefunction(fn): + return await fn(*args, **kwargs) + else: + return fn(*args, **kwargs) + + # Mark function as requiring manifest + wrapper._requires_manifest = True # type: ignore + wrapper._manifest = manifest # type: ignore + wrapper._tool_name = tool_name # type: ignore + + return wrapper + + return decorator + + # Need to add this import import asyncio # noqa: E402 diff --git a/nextmcp/auth/oauth.py b/nextmcp/auth/oauth.py new file mode 100644 index 0000000..67321c0 --- /dev/null +++ b/nextmcp/auth/oauth.py @@ -0,0 +1,331 @@ +""" +OAuth 2.0 authentication providers for NextMCP. + +Implements OAuth 2.0 Authorization Code Flow with PKCE support. +""" + +import base64 +import hashlib +import secrets +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any + +from nextmcp.auth.core import AuthContext, AuthProvider, AuthResult, Permission + + +@dataclass +class OAuthConfig: + """OAuth provider configuration.""" + + client_id: str + client_secret: str | None = None # Optional for PKCE + authorization_url: str = "" + token_url: str = "" + redirect_uri: str = "http://localhost:8080/oauth/callback" + scope: list[str] = field(default_factory=list) # OAuth scopes to request + + +@dataclass +class PKCEChallenge: + """PKCE challenge data for OAuth 2.0.""" + + verifier: str + challenge: str + method: str = "S256" + + @classmethod + def generate(cls) -> "PKCEChallenge": + """ + Generate a new PKCE challenge. + + Returns: + PKCEChallenge with verifier and challenge + """ + # Generate cryptographically secure verifier (43-128 characters) + verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=") + + # Create SHA256 challenge + challenge = ( + base64.urlsafe_b64encode(hashlib.sha256(verifier.encode("utf-8")).digest()) + .decode("utf-8") + .rstrip("=") + ) + + return cls(verifier=verifier, challenge=challenge, method="S256") + + +class OAuthProvider(AuthProvider, ABC): + """ + Base OAuth 2.0 provider with PKCE support. + + Implements Authorization Code Flow with optional PKCE. + Subclasses implement provider-specific details. + """ + + def __init__(self, config: OAuthConfig, **kwargs: Any): + """ + Initialize OAuth provider. + + Args: + config: OAuth configuration + **kwargs: Additional provider configuration + """ + super().__init__(**kwargs) + self.config = config + self._pending_auth: dict[str, PKCEChallenge] = {} # state -> PKCE + + def generate_authorization_url(self, state: str | None = None) -> dict[str, str]: + """ + Generate OAuth authorization URL with PKCE. + + Args: + state: Optional state parameter for CSRF protection + + Returns: + Dict with 'url', 'state', and 'verifier' (store securely!) + """ + if not state: + state = secrets.token_urlsafe(32) + + # Generate PKCE challenge + pkce = PKCEChallenge.generate() + self._pending_auth[state] = pkce + + params = { + "client_id": self.config.client_id, + "redirect_uri": self.config.redirect_uri, + "response_type": "code", + "state": state, + "code_challenge": pkce.challenge, + "code_challenge_method": pkce.method, + } + + if self.config.scope: + params["scope"] = " ".join(self.config.scope) + + # Add provider-specific parameters + params.update(self.get_additional_auth_params()) + + # Properly encode URL parameters + from urllib.parse import urlencode + + query_string = urlencode(params) + url = f"{self.config.authorization_url}?{query_string}" + + return { + "url": url, + "state": state, + "verifier": pkce.verifier, # Client must store this! + } + + async def exchange_code_for_token( + self, code: str, state: str, verifier: str | None = None + ) -> dict[str, Any]: + """ + Exchange authorization code for access token. + + Args: + code: Authorization code from OAuth callback + state: State parameter for CSRF protection + verifier: PKCE verifier (if not stored in provider) + + Returns: + Token response with access_token, refresh_token, etc. + + Raises: + ValueError: If state is invalid or token exchange fails + """ + import aiohttp + + # Get PKCE verifier + if verifier is None: + pkce = self._pending_auth.pop(state, None) + if not pkce: + raise ValueError("Invalid state or expired authorization") + verifier = pkce.verifier + + # Build token request + data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": self.config.redirect_uri, + "client_id": self.config.client_id, + "code_verifier": verifier, + } + + # Add client secret if provided (confidential clients) + if self.config.client_secret: + data["client_secret"] = self.config.client_secret + + async with aiohttp.ClientSession() as session: + async with session.post(self.config.token_url, data=data) as resp: + if resp.status != 200: + # Try to get error details + try: + error_data = await resp.json() + except Exception: + error_data = await resp.text() + raise ValueError(f"Token exchange failed: {error_data}") + + # GitHub returns form-encoded, Google returns JSON + content_type = resp.headers.get("Content-Type", "") + if "application/json" in content_type: + return await resp.json() + else: + # Parse form-encoded response (GitHub uses this) + from urllib.parse import parse_qs + + text = await resp.text() + parsed = parse_qs(text) + # Convert lists to single values where appropriate + return {k: v[0] if len(v) == 1 else v for k, v in parsed.items()} + + async def refresh_access_token(self, refresh_token: str) -> dict[str, Any]: + """ + Refresh an access token using a refresh token. + + Args: + refresh_token: The refresh token + + Returns: + New token response + + Raises: + ValueError: If token refresh fails + """ + import aiohttp + + data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": self.config.client_id, + } + + if self.config.client_secret: + data["client_secret"] = self.config.client_secret + + async with aiohttp.ClientSession() as session: + async with session.post(self.config.token_url, data=data) as resp: + if resp.status != 200: + # Try to get error details + try: + error_data = await resp.json() + except Exception: + error_data = await resp.text() + raise ValueError(f"Token refresh failed: {error_data}") + + # GitHub returns form-encoded, Google returns JSON + content_type = resp.headers.get("Content-Type", "") + if "application/json" in content_type: + return await resp.json() + else: + # Parse form-encoded response (GitHub uses this) + from urllib.parse import parse_qs + + text = await resp.text() + parsed = parse_qs(text) + # Convert lists to single values where appropriate + return {k: v[0] if len(v) == 1 else v for k, v in parsed.items()} + + @abstractmethod + async def get_user_info(self, access_token: str) -> dict[str, Any]: + """ + Get user information from OAuth provider. + + Args: + access_token: OAuth access token + + Returns: + User information dictionary + + Raises: + ValueError: If user info retrieval fails + """ + pass + + @abstractmethod + def get_additional_auth_params(self) -> dict[str, str]: + """ + Get provider-specific authorization parameters. + + Returns: + Dictionary of additional parameters to add to auth URL + """ + return {} + + async def authenticate(self, credentials: dict[str, Any]) -> AuthResult: + """ + Authenticate using OAuth access token. + + Expected credentials: + { + "access_token": "oauth_access_token", + "refresh_token": "oauth_refresh_token", # optional + "scopes": ["scope1", "scope2"] # optional + } + + Args: + credentials: Authentication credentials + + Returns: + AuthResult with success status and auth context + """ + access_token = credentials.get("access_token") + if not access_token: + return AuthResult.failure("Missing access_token") + + try: + # Get user info from OAuth provider + user_info = await self.get_user_info(access_token) + + # Build auth context + context = AuthContext( + authenticated=True, + user_id=self.extract_user_id(user_info), + username=self.extract_username(user_info), + metadata={ + "oauth_provider": self.name, + "access_token": access_token, + "refresh_token": credentials.get("refresh_token"), + "user_info": user_info, + }, + ) + + # Add OAuth scopes as both scopes and permissions + # This maintains backward compatibility while enabling scope-specific features + for scope in credentials.get("scopes", []): + context.add_scope(scope) # Add as OAuth scope + context.add_permission( + Permission(scope) + ) # Also add as permission for backward compat + + return AuthResult.success_result(context) + + except Exception as e: + return AuthResult.failure(f"OAuth authentication failed: {e}") + + @abstractmethod + def extract_user_id(self, user_info: dict[str, Any]) -> str: + """ + Extract user ID from provider's user info. + + Args: + user_info: User information from OAuth provider + + Returns: + User ID string + """ + pass + + def extract_username(self, user_info: dict[str, Any]) -> str | None: + """ + Extract username from provider's user info. + + Args: + user_info: User information from OAuth provider + + Returns: + Username string or None + """ + return user_info.get("login") or user_info.get("email") diff --git a/nextmcp/auth/oauth_providers.py b/nextmcp/auth/oauth_providers.py new file mode 100644 index 0000000..fa9162c --- /dev/null +++ b/nextmcp/auth/oauth_providers.py @@ -0,0 +1,201 @@ +""" +Ready-to-use OAuth providers for common services. + +Provides GitHub and Google OAuth providers with sensible defaults. +""" + +from typing import Any + +from nextmcp.auth.oauth import OAuthConfig, OAuthProvider + + +class GitHubOAuthProvider(OAuthProvider): + """ + GitHub OAuth provider. + + Implements OAuth 2.0 for GitHub with standard scopes. + """ + + def __init__( + self, + client_id: str, + client_secret: str | None = None, + redirect_uri: str = "http://localhost:8080/oauth/callback", + scope: list[str] | None = None, + **kwargs: Any, + ): + """ + Initialize GitHub OAuth provider. + + Args: + client_id: GitHub OAuth app client ID + client_secret: GitHub OAuth app client secret (optional for PKCE) + redirect_uri: OAuth callback URI + scope: OAuth scopes to request (default: ["read:user"]) + **kwargs: Additional configuration + """ + config = OAuthConfig( + client_id=client_id, + client_secret=client_secret, + authorization_url="https://github.com/login/oauth/authorize", + token_url="https://github.com/login/oauth/access_token", + redirect_uri=redirect_uri, + scope=scope or ["read:user"], + ) + super().__init__(config, **kwargs) + + async def get_user_info(self, access_token: str) -> dict[str, Any]: + """ + Get GitHub user information. + + Args: + access_token: GitHub access token + + Returns: + User information dictionary + + Raises: + ValueError: If user info retrieval fails + """ + import aiohttp + + headers = { + "Authorization": f"Bearer {access_token}", + "Accept": "application/json", + } + + async with aiohttp.ClientSession() as session: + async with session.get("https://api.github.com/user", headers=headers) as resp: + if resp.status != 200: + raise ValueError(f"Failed to get user info: {await resp.text()}") + return await resp.json() + + def extract_user_id(self, user_info: dict[str, Any]) -> str: + """ + Extract GitHub user ID. + + Args: + user_info: GitHub user information + + Returns: + User ID as string + """ + return str(user_info["id"]) + + def extract_username(self, user_info: dict[str, Any]) -> str | None: + """ + Extract GitHub username. + + Args: + user_info: GitHub user information + + Returns: + GitHub login username + """ + return user_info.get("login") + + def get_additional_auth_params(self) -> dict[str, str]: + """ + GitHub-specific authorization parameters. + + Returns: + Empty dict (GitHub doesn't need additional params) + """ + return {} + + +class GoogleOAuthProvider(OAuthProvider): + """ + Google OAuth provider. + + Implements OAuth 2.0 for Google with standard scopes. + """ + + def __init__( + self, + client_id: str, + client_secret: str, + redirect_uri: str = "http://localhost:8080/oauth/callback", + scope: list[str] | None = None, + **kwargs: Any, + ): + """ + Initialize Google OAuth provider. + + Args: + client_id: Google OAuth app client ID + client_secret: Google OAuth app client secret + redirect_uri: OAuth callback URI + scope: OAuth scopes to request (default: ["openid", "email", "profile"]) + **kwargs: Additional configuration + """ + config = OAuthConfig( + client_id=client_id, + client_secret=client_secret, + authorization_url="https://accounts.google.com/o/oauth2/v2/auth", + token_url="https://oauth2.googleapis.com/token", + redirect_uri=redirect_uri, + scope=scope or ["openid", "email", "profile"], + ) + super().__init__(config, **kwargs) + + async def get_user_info(self, access_token: str) -> dict[str, Any]: + """ + Get Google user information. + + Args: + access_token: Google access token + + Returns: + User information dictionary + + Raises: + ValueError: If user info retrieval fails + """ + import aiohttp + + headers = {"Authorization": f"Bearer {access_token}"} + + async with aiohttp.ClientSession() as session: + async with session.get( + "https://www.googleapis.com/oauth2/v2/userinfo", headers=headers + ) as resp: + if resp.status != 200: + raise ValueError(f"Failed to get user info: {await resp.text()}") + return await resp.json() + + def extract_user_id(self, user_info: dict[str, Any]) -> str: + """ + Extract Google user ID. + + Args: + user_info: Google user information + + Returns: + User ID as string + """ + return user_info["id"] + + def extract_username(self, user_info: dict[str, Any]) -> str | None: + """ + Extract Google email as username. + + Args: + user_info: Google user information + + Returns: + User's email address + """ + return user_info.get("email") + + def get_additional_auth_params(self) -> dict[str, str]: + """ + Google-specific authorization parameters. + + Returns: + Dict with access_type and prompt parameters for refresh tokens + """ + return { + "access_type": "offline", # Request refresh token + "prompt": "consent", + } diff --git a/nextmcp/auth/request_middleware.py b/nextmcp/auth/request_middleware.py new file mode 100644 index 0000000..71c3c02 --- /dev/null +++ b/nextmcp/auth/request_middleware.py @@ -0,0 +1,360 @@ +""" +Request-level auth enforcement middleware for NextMCP. + +This module provides middleware that intercepts ALL MCP requests and enforces +authentication and authorization automatically, without requiring decorators +on individual tools. + +This is the runtime enforcement layer that makes auth actually work in production. +""" + +import logging +from collections.abc import Callable +from typing import Any + +from nextmcp.auth.core import AuthContext, AuthProvider, AuthResult +from nextmcp.auth.errors import AuthenticationError, AuthorizationError +from nextmcp.auth.manifest import PermissionManifest +from nextmcp.protocol.auth_metadata import AuthMetadata, AuthRequirement +from nextmcp.session.session_store import SessionStore + +logger = logging.getLogger(__name__) + + +class AuthEnforcementMiddleware: + """ + Middleware that enforces authentication and authorization on every request. + + This middleware: + 1. Extracts auth credentials from request + 2. Validates tokens using the auth provider + 3. Loads session data + 4. Checks scopes and permissions + 5. Populates auth context + 6. Rejects unauthorized requests with structured errors + + Example: + # In your MCP server setup + auth_middleware = AuthEnforcementMiddleware( + provider=google_oauth, + session_store=MemorySessionStore(), + metadata=auth_metadata, + manifest=permission_manifest + ) + + # Apply to all requests + server.use(auth_middleware) + """ + + def __init__( + self, + provider: AuthProvider, + session_store: SessionStore | None = None, + metadata: AuthMetadata | None = None, + manifest: PermissionManifest | None = None, + credentials_key: str = "auth", + auto_refresh_tokens: bool = True, + ): + """ + Initialize auth enforcement middleware. + + Args: + provider: Auth provider for validation + session_store: Session storage (optional) + metadata: Auth metadata for requirement checking + manifest: Permission manifest for tool requirements + credentials_key: Key in request where credentials are found + auto_refresh_tokens: Automatically refresh expired tokens + """ + self.provider = provider + self.session_store = session_store + self.metadata = metadata or AuthMetadata() + self.manifest = manifest + self.credentials_key = credentials_key + self.auto_refresh_tokens = auto_refresh_tokens + + async def __call__( + self, + request: dict[str, Any], + handler: Callable, + ) -> Any: + """ + Process request with auth enforcement. + + Args: + request: MCP request dictionary + handler: Next middleware/handler in chain + + Returns: + Response from handler + + Raises: + AuthenticationError: If authentication fails + AuthorizationError: If authorization fails + """ + # Check if authentication is required + if self.metadata.requirement == AuthRequirement.NONE: + # No auth required, pass through + return await handler(request) + + # Extract credentials from request + credentials = request.get(self.credentials_key, {}) + + # If auth is optional and no credentials provided, allow request + if self.metadata.requirement == AuthRequirement.OPTIONAL and not credentials: + logger.debug("Optional auth: no credentials provided, allowing request") + return await handler(request) + + # Auth is required or credentials were provided + if not credentials: + raise AuthenticationError( + "Authentication required but no credentials provided", + required_scopes=self.metadata.required_scopes, + providers=self.metadata.providers, + ) + + # Authenticate using provider + auth_result = await self._authenticate(credentials, request) + + if not auth_result.success: + raise AuthenticationError( + auth_result.error or "Authentication failed", + required_scopes=self.metadata.required_scopes, + providers=self.metadata.providers, + ) + + auth_context = auth_result.context + + # Check authorization (scopes, permissions, manifest) + self._check_authorization(auth_context, request) + + # Inject auth context into request for handlers + request["_auth_context"] = auth_context + + # Call next handler + return await handler(request) + + async def _authenticate( + self, + credentials: dict[str, Any], + request: dict[str, Any], + ) -> AuthResult: + """ + Authenticate the request. + + Args: + credentials: Auth credentials + request: Full request data + + Returns: + AuthResult with success status and context + """ + # Extract access token + access_token = credentials.get("access_token") + if not access_token: + return AuthResult.failure("No access_token in credentials") + + # Check session store first + user_id = None + if self.session_store: + # Try to find user by token (this is a simple implementation) + # In production, you might want to decode JWT or lookup by token hash + for uid in self.session_store.list_users(): + session = self.session_store.load(uid) + if session and session.access_token == access_token: + user_id = uid + + # Check if token needs refresh + if self.auto_refresh_tokens and session.needs_refresh(): + logger.info(f"Token expiring soon for user {user_id}, refreshing...") + try: + await self._refresh_token(session) + except Exception as e: + logger.warning(f"Token refresh failed for {user_id}: {e}") + + # Check if token is expired + if session.is_expired(): + return AuthResult.failure("Access token expired") + + # Build auth context from session + auth_context = AuthContext( + authenticated=True, + user_id=session.user_id, + username=session.user_info.get("login") or session.user_info.get("email"), + metadata={ + "oauth_provider": session.provider, + "access_token": session.access_token, + "refresh_token": session.refresh_token, + "user_info": session.user_info, + }, + ) + + # Add scopes from session + for scope in session.scopes: + auth_context.add_scope(scope) + + return AuthResult.success_result(auth_context) + + # No session found, authenticate with provider + result = await self.provider.authenticate(credentials) + + # If successful and we have session store, save session + if result.success and self.session_store and result.context: + try: + from nextmcp.session.session_store import SessionData + + session = SessionData( + user_id=result.context.user_id, + access_token=access_token, + refresh_token=credentials.get("refresh_token"), + scopes=list(result.context.scopes), + user_info=result.context.metadata.get("user_info", {}), + provider=self.provider.name, + ) + self.session_store.save(session) + logger.info(f"Created new session for user: {result.context.user_id}") + except Exception as e: + logger.warning(f"Failed to save session: {e}") + + return result + + async def _refresh_token(self, session: "SessionData") -> None: + """ + Refresh an expired token. + + Args: + session: Session data with refresh token + + Raises: + ValueError: If refresh fails + """ + if not session.refresh_token: + raise ValueError("No refresh token available") + + # Import OAuth provider types + from nextmcp.auth.oauth import OAuthProvider + + if not isinstance(self.provider, OAuthProvider): + raise ValueError("Token refresh only supported for OAuth providers") + + # Refresh token + token_data = await self.provider.refresh_access_token(session.refresh_token) + + # Update session + if self.session_store: + import time + + session.access_token = token_data.get("access_token") + if "refresh_token" in token_data: + session.refresh_token = token_data["refresh_token"] + if "expires_in" in token_data: + session.expires_at = time.time() + token_data["expires_in"] + + self.session_store.save(session) + logger.info(f"Refreshed token for user: {session.user_id}") + + def _check_authorization( + self, + auth_context: AuthContext, + request: dict[str, Any], + ) -> None: + """ + Check if user is authorized for this request. + + Args: + auth_context: Authenticated user context + request: Request data + + Raises: + AuthorizationError: If user lacks required authorization + """ + # Extract tool name from request (MCP format) + tool_name = request.get("params", {}).get("name") + if not tool_name: + # Not a tool call, allow + return + + # Check required scopes from metadata + if self.metadata.required_scopes: + has_all_scopes = all( + auth_context.has_scope(scope) for scope in self.metadata.required_scopes + ) + if not has_all_scopes: + missing_scopes = [ + scope + for scope in self.metadata.required_scopes + if not auth_context.has_scope(scope) + ] + raise AuthorizationError( + f"Missing required scopes: {', '.join(missing_scopes)}", + required=missing_scopes, + user_id=auth_context.user_id, + ) + + # Check manifest if provided + if self.manifest: + allowed, error_message = self.manifest.check_tool_access(tool_name, auth_context) + if not allowed: + # Get tool definition for detailed error + tool_def = self.manifest.tools.get(tool_name) + + from nextmcp.auth.errors import ManifestViolationError + + raise ManifestViolationError( + message=error_message or "Access denied by permission manifest", + tool_name=tool_name, + required_roles=tool_def.roles if tool_def else [], + required_permissions=tool_def.permissions if tool_def else [], + required_scopes=tool_def.scopes if tool_def else [], + user_id=auth_context.user_id, + auth_context=auth_context, + ) + + logger.debug(f"Authorization check passed for {auth_context.user_id} on {tool_name}") + + +class SessionData: + """Forward declaration for type hints (actual class in session_store.py).""" + + pass + + +def create_auth_middleware( + provider: AuthProvider, + requirement: AuthRequirement = AuthRequirement.REQUIRED, + session_store: SessionStore | None = None, + manifest: PermissionManifest | None = None, + required_scopes: list[str] | None = None, +) -> AuthEnforcementMiddleware: + """ + Helper function to create auth enforcement middleware. + + Args: + provider: Auth provider + requirement: Auth requirement level + session_store: Session storage (optional) + manifest: Permission manifest (optional) + required_scopes: Required OAuth scopes (optional) + + Returns: + Configured AuthEnforcementMiddleware + + Example: + middleware = create_auth_middleware( + provider=github_oauth, + requirement=AuthRequirement.REQUIRED, + session_store=MemorySessionStore(), + required_scopes=["repo", "user"] + ) + """ + metadata = AuthMetadata( + requirement=requirement, + required_scopes=required_scopes or [], + ) + + return AuthEnforcementMiddleware( + provider=provider, + session_store=session_store, + metadata=metadata, + manifest=manifest, + ) diff --git a/nextmcp/protocol/__init__.py b/nextmcp/protocol/__init__.py new file mode 100644 index 0000000..eee7dc6 --- /dev/null +++ b/nextmcp/protocol/__init__.py @@ -0,0 +1,19 @@ +""" +NextMCP Protocol Extensions. + +This module defines protocol-level extensions for MCP servers. +""" + +from nextmcp.protocol.auth_metadata import ( + AuthFlowType, + AuthMetadata, + AuthProviderMetadata, + AuthRequirement, +) + +__all__ = [ + "AuthMetadata", + "AuthProviderMetadata", + "AuthRequirement", + "AuthFlowType", +] diff --git a/nextmcp/protocol/auth_metadata.py b/nextmcp/protocol/auth_metadata.py new file mode 100644 index 0000000..0ecd692 --- /dev/null +++ b/nextmcp/protocol/auth_metadata.py @@ -0,0 +1,339 @@ +""" +Auth Metadata Protocol for NextMCP. + +This module defines the protocol-level metadata that MCP servers use to announce +their authentication and authorization requirements to clients/hosts. + +This is the critical piece that allows hosts (like Claude Desktop, Cursor, etc.) +to discover what auth a server needs and present the appropriate UI. +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + + +class AuthFlowType(str, Enum): + """Supported authentication flow types.""" + + OAUTH2_PKCE = "oauth2-pkce" + OAUTH2_CLIENT_CREDENTIALS = "oauth2-client-credentials" + API_KEY = "api-key" + JWT = "jwt" + BASIC = "basic" + CUSTOM = "custom" + + +class AuthRequirement(str, Enum): + """Authentication requirement levels.""" + + REQUIRED = "required" # All requests must be authenticated + OPTIONAL = "optional" # Authentication enhances functionality but isn't required + NONE = "none" # No authentication + + +@dataclass +class AuthProviderMetadata: + """ + Metadata for a single OAuth/auth provider. + + This describes one authentication provider (e.g., Google, GitHub) + that the server supports. + """ + + name: str # Provider name: "google", "github", etc. + type: str # Provider type: "oauth2", "api-key", etc. + flows: list[AuthFlowType] # Supported flows + authorization_url: str | None = None # OAuth authorization endpoint + token_url: str | None = None # OAuth token endpoint + scopes: list[str] = field(default_factory=list) # Available OAuth scopes + supports_refresh: bool = False # Whether refresh tokens are supported + supports_pkce: bool = True # Whether PKCE is supported + metadata_url: str | None = None # Well-known metadata URL + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "name": self.name, + "type": self.type, + "flows": [flow.value for flow in self.flows], + "authorization_url": self.authorization_url, + "token_url": self.token_url, + "scopes": self.scopes, + "supports_refresh": self.supports_refresh, + "supports_pkce": self.supports_pkce, + "metadata_url": self.metadata_url, + } + + +@dataclass +class AuthMetadata: + """ + Complete authentication metadata for an MCP server. + + This is the top-level structure that servers expose to announce their + authentication requirements, supported providers, scopes, and permissions. + + Example JSON representation: + { + "auth": { + "requirement": "required", + "providers": [ + { + "name": "google", + "type": "oauth2", + "flows": ["oauth2-pkce"], + "authorization_url": "https://accounts.google.com/o/oauth2/v2/auth", + "token_url": "https://oauth2.googleapis.com/token", + "scopes": ["profile", "email", "drive.readonly"], + "supports_refresh": true, + "supports_pkce": true + } + ], + "required_scopes": ["profile", "email"], + "optional_scopes": ["drive.readonly", "gmail.readonly"], + "permissions": ["file.read", "email.send"], + "supports_multi_user": true, + "session_management": "server-side" + } + } + """ + + requirement: AuthRequirement = AuthRequirement.NONE + providers: list[AuthProviderMetadata] = field(default_factory=list) + required_scopes: list[str] = field(default_factory=list) # Minimum scopes needed + optional_scopes: list[str] = field(default_factory=list) # Additional scopes + permissions: list[str] = field(default_factory=list) # Custom permissions + roles: list[str] = field(default_factory=list) # Available roles + supports_multi_user: bool = False # Multi-user/multi-tenant support + session_management: str = "server-side" # "server-side", "client-side", "stateless" + token_refresh_enabled: bool = False # Server handles token refresh + error_codes: dict[str, str] = field(default_factory=dict) # Auth error code docs + + def to_dict(self) -> dict[str, Any]: + """ + Convert to dictionary for JSON serialization. + + Returns: + Dictionary representation suitable for JSON serialization + """ + return { + "requirement": self.requirement.value, + "providers": [provider.to_dict() for provider in self.providers], + "required_scopes": self.required_scopes, + "optional_scopes": self.optional_scopes, + "permissions": self.permissions, + "roles": self.roles, + "supports_multi_user": self.supports_multi_user, + "session_management": self.session_management, + "token_refresh_enabled": self.token_refresh_enabled, + "error_codes": self.error_codes, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "AuthMetadata": + """ + Create AuthMetadata from dictionary. + + Args: + data: Dictionary representation + + Returns: + AuthMetadata instance + """ + providers = [ + AuthProviderMetadata( + name=p["name"], + type=p["type"], + flows=[AuthFlowType(f) for f in p.get("flows", [])], + authorization_url=p.get("authorization_url"), + token_url=p.get("token_url"), + scopes=p.get("scopes", []), + supports_refresh=p.get("supports_refresh", False), + supports_pkce=p.get("supports_pkce", True), + metadata_url=p.get("metadata_url"), + ) + for p in data.get("providers", []) + ] + + return cls( + requirement=AuthRequirement(data.get("requirement", "none")), + providers=providers, + required_scopes=data.get("required_scopes", []), + optional_scopes=data.get("optional_scopes", []), + permissions=data.get("permissions", []), + roles=data.get("roles", []), + supports_multi_user=data.get("supports_multi_user", False), + session_management=data.get("session_management", "server-side"), + token_refresh_enabled=data.get("token_refresh_enabled", False), + error_codes=data.get("error_codes", {}), + ) + + def add_provider( + self, + name: str, + type: str, + flows: list[AuthFlowType], + authorization_url: str | None = None, + token_url: str | None = None, + scopes: list[str] | None = None, + supports_refresh: bool = False, + supports_pkce: bool = True, + ) -> None: + """ + Add an authentication provider. + + Args: + name: Provider name (e.g., "google", "github") + type: Provider type (e.g., "oauth2", "api-key") + flows: Supported authentication flows + authorization_url: OAuth authorization endpoint + token_url: OAuth token endpoint + scopes: Available scopes + supports_refresh: Whether refresh tokens are supported + supports_pkce: Whether PKCE is supported + """ + provider = AuthProviderMetadata( + name=name, + type=type, + flows=flows, + authorization_url=authorization_url, + token_url=token_url, + scopes=scopes or [], + supports_refresh=supports_refresh, + supports_pkce=supports_pkce, + ) + self.providers.append(provider) + + def add_required_scope(self, scope: str) -> None: + """Add a required OAuth scope.""" + if scope not in self.required_scopes: + self.required_scopes.append(scope) + + def add_optional_scope(self, scope: str) -> None: + """Add an optional OAuth scope.""" + if scope not in self.optional_scopes: + self.optional_scopes.append(scope) + + def add_permission(self, permission: str) -> None: + """Add a custom permission.""" + if permission not in self.permissions: + self.permissions.append(permission) + + def add_role(self, role: str) -> None: + """Add a role.""" + if role not in self.roles: + self.roles.append(role) + + def validate(self) -> tuple[bool, list[str]]: + """ + Validate the metadata configuration. + + Returns: + Tuple of (is_valid, list_of_errors) + """ + errors = [] + + if self.requirement == AuthRequirement.REQUIRED and not self.providers: + errors.append("Authentication is required but no providers configured") + + for provider in self.providers: + if provider.type == "oauth2": + if not provider.authorization_url: + errors.append(f"Provider '{provider.name}' missing authorization_url") + if not provider.token_url: + errors.append(f"Provider '{provider.name}' missing token_url") + + # Check for scope conflicts + scope_overlap = set(self.required_scopes) & set(self.optional_scopes) + if scope_overlap: + errors.append(f"Scopes cannot be both required and optional: {scope_overlap}") + + return len(errors) == 0, errors + + +# JSON Schema for validation (can be used by hosts) +AUTH_METADATA_SCHEMA = { + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "NextMCP Auth Metadata", + "description": "Authentication metadata for MCP servers", + "type": "object", + "properties": { + "requirement": { + "type": "string", + "enum": ["required", "optional", "none"], + "description": "Authentication requirement level", + }, + "providers": { + "type": "array", + "description": "List of supported authentication providers", + "items": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "type": {"type": "string"}, + "flows": { + "type": "array", + "items": { + "type": "string", + "enum": [ + "oauth2-pkce", + "oauth2-client-credentials", + "api-key", + "jwt", + "basic", + "custom", + ], + }, + }, + "authorization_url": {"type": ["string", "null"]}, + "token_url": {"type": ["string", "null"]}, + "scopes": {"type": "array", "items": {"type": "string"}}, + "supports_refresh": {"type": "boolean"}, + "supports_pkce": {"type": "boolean"}, + "metadata_url": {"type": ["string", "null"]}, + }, + "required": ["name", "type", "flows"], + }, + }, + "required_scopes": { + "type": "array", + "items": {"type": "string"}, + "description": "Minimum required OAuth scopes", + }, + "optional_scopes": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional OAuth scopes that enhance functionality", + }, + "permissions": { + "type": "array", + "items": {"type": "string"}, + "description": "Custom permission strings", + }, + "roles": { + "type": "array", + "items": {"type": "string"}, + "description": "Available user roles", + }, + "supports_multi_user": { + "type": "boolean", + "description": "Whether server supports multiple users", + }, + "session_management": { + "type": "string", + "enum": ["server-side", "client-side", "stateless"], + "description": "Session management strategy", + }, + "token_refresh_enabled": { + "type": "boolean", + "description": "Whether server handles token refresh", + }, + "error_codes": { + "type": "object", + "description": "Documentation for auth error codes", + "additionalProperties": {"type": "string"}, + }, + }, + "required": ["requirement"], +} diff --git a/nextmcp/session/__init__.py b/nextmcp/session/__init__.py new file mode 100644 index 0000000..46ddf87 --- /dev/null +++ b/nextmcp/session/__init__.py @@ -0,0 +1,20 @@ +""" +Session management for NextMCP. + +This module provides session storage for OAuth tokens, user identity, +and session state management. +""" + +from nextmcp.session.session_store import ( + FileSessionStore, + MemorySessionStore, + SessionData, + SessionStore, +) + +__all__ = [ + "SessionStore", + "SessionData", + "MemorySessionStore", + "FileSessionStore", +] diff --git a/nextmcp/session/session_store.py b/nextmcp/session/session_store.py new file mode 100644 index 0000000..61571ec --- /dev/null +++ b/nextmcp/session/session_store.py @@ -0,0 +1,375 @@ +""" +Session store implementations for NextMCP. + +Provides pluggable session storage for OAuth tokens, user identity, +and session state. Supports multiple backends (memory, file, Redis, etc.). +""" + +import json +import logging +import threading +import time +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class SessionData: + """ + Session data stored for each user. + + Contains OAuth tokens, user information, and session metadata. + """ + + user_id: str + access_token: str | None = None + refresh_token: str | None = None + token_type: str = "Bearer" + expires_at: float | None = None # Unix timestamp + scopes: list[str] = field(default_factory=list) + user_info: dict[str, Any] = field(default_factory=dict) + provider: str | None = None # OAuth provider name + created_at: float = field(default_factory=time.time) + updated_at: float = field(default_factory=time.time) + metadata: dict[str, Any] = field(default_factory=dict) # Custom data + + def is_expired(self) -> bool: + """Check if access token is expired.""" + if self.expires_at is None: + return False + return time.time() >= self.expires_at + + def needs_refresh(self, buffer_seconds: int = 300) -> bool: + """ + Check if token needs refreshing. + + Args: + buffer_seconds: Refresh if expiring within this many seconds (default 5 min) + + Returns: + True if token should be refreshed + """ + if self.expires_at is None: + return False + return time.time() >= (self.expires_at - buffer_seconds) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "SessionData": + """Create SessionData from dictionary.""" + return cls(**data) + + +class SessionStore(ABC): + """ + Abstract base class for session storage. + + Subclasses implement different storage backends (memory, file, Redis, etc.). + """ + + @abstractmethod + def save(self, session: SessionData) -> None: + """ + Save session data. + + Args: + session: Session data to save + """ + pass + + @abstractmethod + def load(self, user_id: str) -> SessionData | None: + """ + Load session data for a user. + + Args: + user_id: User identifier + + Returns: + SessionData if found, None otherwise + """ + pass + + @abstractmethod + def delete(self, user_id: str) -> bool: + """ + Delete session data for a user. + + Args: + user_id: User identifier + + Returns: + True if session was deleted, False if not found + """ + pass + + @abstractmethod + def exists(self, user_id: str) -> bool: + """ + Check if session exists for a user. + + Args: + user_id: User identifier + + Returns: + True if session exists + """ + pass + + @abstractmethod + def list_users(self) -> list[str]: + """ + List all user IDs with active sessions. + + Returns: + List of user IDs + """ + pass + + @abstractmethod + def clear_all(self) -> int: + """ + Clear all sessions. + + Returns: + Number of sessions deleted + """ + pass + + def update_tokens( + self, + user_id: str, + access_token: str, + refresh_token: str | None = None, + expires_in: int | None = None, + ) -> None: + """ + Update tokens for an existing session. + + Args: + user_id: User identifier + access_token: New access token + refresh_token: New refresh token (optional) + expires_in: Token expiration in seconds (optional) + """ + session = self.load(user_id) + if not session: + raise ValueError(f"No session found for user: {user_id}") + + session.access_token = access_token + if refresh_token: + session.refresh_token = refresh_token + if expires_in: + session.expires_at = time.time() + expires_in + session.updated_at = time.time() + + self.save(session) + + +class MemorySessionStore(SessionStore): + """ + In-memory session storage. + + Sessions are stored in RAM and lost when the process restarts. + Useful for development and testing. + + Thread-safe with locking. + """ + + def __init__(self): + """Initialize memory session store.""" + self._sessions: dict[str, SessionData] = {} + self._lock = threading.RLock() + + def save(self, session: SessionData) -> None: + """Save session to memory.""" + with self._lock: + session.updated_at = time.time() + self._sessions[session.user_id] = session + logger.debug(f"Saved session for user: {session.user_id}") + + def load(self, user_id: str) -> SessionData | None: + """Load session from memory.""" + with self._lock: + session = self._sessions.get(user_id) + if session: + logger.debug(f"Loaded session for user: {user_id}") + return session + + def delete(self, user_id: str) -> bool: + """Delete session from memory.""" + with self._lock: + if user_id in self._sessions: + del self._sessions[user_id] + logger.debug(f"Deleted session for user: {user_id}") + return True + return False + + def exists(self, user_id: str) -> bool: + """Check if session exists in memory.""" + with self._lock: + return user_id in self._sessions + + def list_users(self) -> list[str]: + """List all user IDs in memory.""" + with self._lock: + return list(self._sessions.keys()) + + def clear_all(self) -> int: + """Clear all sessions from memory.""" + with self._lock: + count = len(self._sessions) + self._sessions.clear() + logger.info(f"Cleared {count} sessions from memory") + return count + + def cleanup_expired(self) -> int: + """ + Remove expired sessions from memory. + + Returns: + Number of expired sessions removed + """ + with self._lock: + expired = [ + user_id for user_id, session in self._sessions.items() if session.is_expired() + ] + for user_id in expired: + del self._sessions[user_id] + if expired: + logger.info(f"Cleaned up {len(expired)} expired sessions") + return len(expired) + + +class FileSessionStore(SessionStore): + """ + File-based session storage. + + Sessions are stored as JSON files in a directory. + Persists across process restarts. + + Thread-safe with locking. + """ + + def __init__(self, directory: str | Path = ".nextmcp_sessions"): + """ + Initialize file session store. + + Args: + directory: Directory to store session files (default: .nextmcp_sessions) + """ + self.directory = Path(directory) + self.directory.mkdir(parents=True, exist_ok=True) + self._lock = threading.RLock() + logger.info(f"File session store initialized at: {self.directory}") + + def _get_path(self, user_id: str) -> Path: + """Get file path for user session.""" + # Sanitize user_id for filename (replace invalid chars) + safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in user_id) + return self.directory / f"session_{safe_id}.json" + + def save(self, session: SessionData) -> None: + """Save session to file.""" + with self._lock: + session.updated_at = time.time() + path = self._get_path(session.user_id) + try: + with open(path, "w") as f: + json.dump(session.to_dict(), f, indent=2) + logger.debug(f"Saved session for user: {session.user_id}") + except Exception as e: + logger.error(f"Failed to save session for {session.user_id}: {e}") + raise + + def load(self, user_id: str) -> SessionData | None: + """Load session from file.""" + with self._lock: + path = self._get_path(user_id) + if not path.exists(): + return None + + try: + with open(path) as f: + data = json.load(f) + session = SessionData.from_dict(data) + logger.debug(f"Loaded session for user: {user_id}") + return session + except Exception as e: + logger.error(f"Failed to load session for {user_id}: {e}") + return None + + def delete(self, user_id: str) -> bool: + """Delete session file.""" + with self._lock: + path = self._get_path(user_id) + if path.exists(): + try: + path.unlink() + logger.debug(f"Deleted session for user: {user_id}") + return True + except Exception as e: + logger.error(f"Failed to delete session for {user_id}: {e}") + return False + return False + + def exists(self, user_id: str) -> bool: + """Check if session file exists.""" + with self._lock: + return self._get_path(user_id).exists() + + def list_users(self) -> list[str]: + """List all user IDs with session files.""" + with self._lock: + users = [] + for path in self.directory.glob("session_*.json"): + try: + with open(path) as f: + data = json.load(f) + users.append(data["user_id"]) + except Exception as e: + logger.warning(f"Failed to read session file {path}: {e}") + return users + + def clear_all(self) -> int: + """Delete all session files.""" + with self._lock: + count = 0 + for path in self.directory.glob("session_*.json"): + try: + path.unlink() + count += 1 + except Exception as e: + logger.error(f"Failed to delete session file {path}: {e}") + logger.info(f"Cleared {count} sessions from file store") + return count + + def cleanup_expired(self) -> int: + """ + Remove expired session files. + + Returns: + Number of expired sessions removed + """ + with self._lock: + count = 0 + for path in self.directory.glob("session_*.json"): + try: + with open(path) as f: + data = json.load(f) + session = SessionData.from_dict(data) + if session.is_expired(): + path.unlink() + count += 1 + except Exception as e: + logger.warning(f"Failed to check/delete expired session {path}: {e}") + if count: + logger.info(f"Cleaned up {count} expired sessions from file store") + return count diff --git a/pyproject.toml b/pyproject.toml index 3bf3330..652fbdd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ classifiers = [ dependencies = [ "fastmcp>=0.1.0", + "aiohttp>=3.8.0", ] [project.optional-dependencies] @@ -84,6 +85,7 @@ addopts = [ ] markers = [ "asyncio: marks tests as async (using pytest-asyncio)", + "integration: marks tests as integration tests (require real OAuth credentials)", ] [tool.black] diff --git a/scripts/setup_env.sh b/scripts/setup_env.sh new file mode 100755 index 0000000..0cf7d6a --- /dev/null +++ b/scripts/setup_env.sh @@ -0,0 +1,62 @@ +#!/bin/bash +# Setup script for OAuth integration testing environment + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" +ENV_FILE="$PROJECT_ROOT/.env" +ENV_EXAMPLE="$PROJECT_ROOT/.env.example" + +echo "==========================================" +echo "NextMCP OAuth Environment Setup" +echo "==========================================" +echo + +# Check if .env already exists +if [ -f "$ENV_FILE" ]; then + echo "⚠️ .env file already exists at: $ENV_FILE" + echo + read -p "Do you want to overwrite it? (y/N) " -n 1 -r + echo + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + echo "Aborted. Your existing .env file was not modified." + exit 0 + fi +fi + +# Copy example file +echo "Creating .env file from template..." +cp "$ENV_EXAMPLE" "$ENV_FILE" +echo "✓ Created .env file at: $ENV_FILE" +echo + +# Provide instructions +echo "==========================================" +echo "Next Steps:" +echo "==========================================" +echo +echo "1. Get OAuth credentials:" +echo " • GitHub: https://github.com/settings/developers" +echo " • Google: https://console.cloud.google.com" +echo +echo "2. Obtain access tokens using the helper script:" +echo " python examples/auth/oauth_token_helper.py" +echo +echo "3. Edit .env file and fill in your credentials:" +echo " ${EDITOR:-nano} .env" +echo +echo "4. Load environment variables:" +echo " export \$(cat .env | grep -v '^#' | xargs)" +echo +echo "5. Verify setup:" +echo " echo \$GITHUB_CLIENT_ID" +echo " echo \$GITHUB_ACCESS_TOKEN" +echo +echo "6. Run integration tests:" +echo " pytest tests/test_oauth_integration.py -v -m integration" +echo +echo "==========================================" +echo "For detailed setup instructions, see:" +echo "docs/OAUTH_TESTING_SETUP.md" +echo "==========================================" diff --git a/tests/test_auth_errors.py b/tests/test_auth_errors.py new file mode 100644 index 0000000..45d0fcd --- /dev/null +++ b/tests/test_auth_errors.py @@ -0,0 +1,270 @@ +""" +Tests for specialized authentication error types. + +Tests custom exceptions for OAuth, scopes, and manifest violations. +""" + +import pytest + +from nextmcp.auth.core import AuthContext, Permission, Role +from nextmcp.auth.errors import ( + ManifestViolationError, + OAuthRequiredError, + ScopeInsufficientError, +) + + +class TestOAuthRequiredError: + """Tests for OAuthRequiredError exception.""" + + def test_oauth_required_error_basic(self): + """Test creating OAuthRequiredError with basic message.""" + error = OAuthRequiredError("OAuth authentication required") + + assert str(error) == "OAuth authentication required" + assert isinstance(error, Exception) + + def test_oauth_required_error_with_provider(self): + """Test OAuthRequiredError with provider information.""" + error = OAuthRequiredError( + "OAuth authentication required", + provider="github", + scopes=["read:user", "repo:read"], + ) + + assert error.provider == "github" + assert error.scopes == ["read:user", "repo:read"] + assert "OAuth authentication required" in str(error) + + def test_oauth_required_error_with_authorization_url(self): + """Test OAuthRequiredError with authorization URL.""" + error = OAuthRequiredError( + "OAuth required", + provider="google", + authorization_url="https://accounts.google.com/o/oauth2/v2/auth?...", + ) + + assert error.authorization_url == "https://accounts.google.com/o/oauth2/v2/auth?..." + assert error.provider == "google" + + def test_oauth_required_error_attributes(self): + """Test all OAuthRequiredError attributes.""" + error = OAuthRequiredError( + message="Please authenticate", + provider="github", + scopes=["repo:write"], + authorization_url="https://github.com/login/oauth/authorize", + user_id=None, + ) + + assert error.message == "Please authenticate" + assert error.provider == "github" + assert error.scopes == ["repo:write"] + assert error.authorization_url == "https://github.com/login/oauth/authorize" + assert error.user_id is None + + +class TestScopeInsufficientError: + """Tests for ScopeInsufficientError exception.""" + + def test_scope_insufficient_error_basic(self): + """Test creating ScopeInsufficientError with basic message.""" + error = ScopeInsufficientError("Insufficient OAuth scopes") + + assert str(error) == "Insufficient OAuth scopes" + assert isinstance(error, Exception) + + def test_scope_insufficient_error_with_required_scopes(self): + """Test ScopeInsufficientError with required scopes.""" + error = ScopeInsufficientError( + "Missing required scopes", + required_scopes=["repo:write", "admin:org"], + current_scopes=["repo:read"], + ) + + assert error.required_scopes == ["repo:write", "admin:org"] + assert error.current_scopes == ["repo:read"] + assert "Missing required scopes" in str(error) + + def test_scope_insufficient_error_with_user_id(self): + """Test ScopeInsufficientError with user identification.""" + error = ScopeInsufficientError( + "Need admin scope", + required_scopes=["admin:all"], + current_scopes=["read:all"], + user_id="user123", + ) + + assert error.user_id == "user123" + assert error.required_scopes == ["admin:all"] + + def test_scope_insufficient_error_missing_context(self): + """Test ScopeInsufficientError when missing scope data.""" + error = ScopeInsufficientError( + "Scopes required", + required_scopes=["write:data"], + current_scopes=[], + ) + + assert error.required_scopes == ["write:data"] + assert error.current_scopes == [] + + +class TestManifestViolationError: + """Tests for ManifestViolationError exception.""" + + def test_manifest_violation_error_basic(self): + """Test creating ManifestViolationError with basic message.""" + error = ManifestViolationError("Manifest access denied") + + assert str(error) == "Manifest access denied" + assert isinstance(error, Exception) + + def test_manifest_violation_error_with_tool_name(self): + """Test ManifestViolationError with tool name.""" + error = ManifestViolationError( + "Access denied to tool", + tool_name="delete_database", + ) + + assert error.tool_name == "delete_database" + assert "Access denied to tool" in str(error) + + def test_manifest_violation_error_with_requirements(self): + """Test ManifestViolationError with requirement details.""" + error = ManifestViolationError( + "Missing required role", + tool_name="admin_panel", + required_roles=["admin", "superuser"], + required_permissions=["admin:all"], + required_scopes=["admin:full"], + ) + + assert error.tool_name == "admin_panel" + assert error.required_roles == ["admin", "superuser"] + assert error.required_permissions == ["admin:all"] + assert error.required_scopes == ["admin:full"] + + def test_manifest_violation_error_with_user_context(self): + """Test ManifestViolationError with user context.""" + error = ManifestViolationError( + "Unauthorized access attempt", + tool_name="sensitive_operation", + user_id="user456", + auth_context=AuthContext( + authenticated=True, + user_id="user456", + username="testuser", + ), + ) + + assert error.user_id == "user456" + assert error.auth_context is not None + assert error.auth_context.user_id == "user456" + + def test_manifest_violation_error_all_attributes(self): + """Test ManifestViolationError with all attributes.""" + context = AuthContext(authenticated=True, user_id="user789") + context.add_role(Role("viewer")) + + error = ManifestViolationError( + message="Complete denial", + tool_name="dangerous_tool", + required_roles=["admin"], + required_permissions=["write:all"], + required_scopes=["full:access"], + user_id="user789", + auth_context=context, + ) + + assert error.message == "Complete denial" + assert error.tool_name == "dangerous_tool" + assert error.required_roles == ["admin"] + assert error.required_permissions == ["write:all"] + assert error.required_scopes == ["full:access"] + assert error.user_id == "user789" + assert error.auth_context == context + + +class TestErrorHierarchy: + """Tests for error type hierarchy and inheritance.""" + + def test_all_errors_are_exceptions(self): + """Test that all error types inherit from Exception.""" + oauth_error = OAuthRequiredError("test") + scope_error = ScopeInsufficientError("test") + manifest_error = ManifestViolationError("test") + + assert isinstance(oauth_error, Exception) + assert isinstance(scope_error, Exception) + assert isinstance(manifest_error, Exception) + + def test_error_messages_are_strings(self): + """Test that all errors convert to string properly.""" + errors = [ + OAuthRequiredError("OAuth needed"), + ScopeInsufficientError("Scope missing"), + ManifestViolationError("Manifest violation"), + ] + + for error in errors: + assert isinstance(str(error), str) + assert len(str(error)) > 0 + + def test_errors_can_be_raised_and_caught(self): + """Test that errors can be raised and caught.""" + with pytest.raises(OAuthRequiredError) as exc_info: + raise OAuthRequiredError("Test OAuth error") + assert "Test OAuth error" in str(exc_info.value) + + with pytest.raises(ScopeInsufficientError) as exc_info: + raise ScopeInsufficientError("Test scope error") + assert "Test scope error" in str(exc_info.value) + + with pytest.raises(ManifestViolationError) as exc_info: + raise ManifestViolationError("Test manifest error") + assert "Test manifest error" in str(exc_info.value) + + +class TestErrorUsagePatterns: + """Tests for common error usage patterns.""" + + def test_oauth_error_for_missing_token(self): + """Test using OAuthRequiredError when access token is missing.""" + error = OAuthRequiredError( + "Access token required for this operation", + provider="github", + scopes=["repo:read"], + authorization_url="https://github.com/login/oauth/authorize?client_id=...", + ) + + assert error.provider == "github" + assert "repo:read" in error.scopes + + def test_scope_error_for_insufficient_permissions(self): + """Test using ScopeInsufficientError when scopes are insufficient.""" + error = ScopeInsufficientError( + "This operation requires write access", + required_scopes=["repo:write"], + current_scopes=["repo:read"], + user_id="user123", + ) + + assert "repo:write" in error.required_scopes + assert "repo:read" in error.current_scopes + + def test_manifest_error_for_tool_access_denial(self): + """Test using ManifestViolationError when tool access is denied.""" + context = AuthContext(authenticated=True, user_id="user456") + context.add_role(Role("viewer")) + + error = ManifestViolationError( + "Tool requires admin role", + tool_name="delete_all_users", + required_roles=["admin"], + user_id="user456", + auth_context=context, + ) + + assert error.tool_name == "delete_all_users" + assert "admin" in error.required_roles diff --git a/tests/test_auth_metadata.py b/tests/test_auth_metadata.py new file mode 100644 index 0000000..f9830a3 --- /dev/null +++ b/tests/test_auth_metadata.py @@ -0,0 +1,333 @@ +""" +Tests for Auth Metadata Protocol. + +Tests the protocol-level metadata system that allows MCP servers to announce +their authentication requirements to hosts. +""" + +import pytest + +from nextmcp.protocol.auth_metadata import ( + AUTH_METADATA_SCHEMA, + AuthFlowType, + AuthMetadata, + AuthProviderMetadata, + AuthRequirement, +) + + +class TestAuthProviderMetadata: + """Test AuthProviderMetadata functionality.""" + + def test_create_oauth_provider(self): + """Test creating OAuth provider metadata.""" + provider = AuthProviderMetadata( + name="google", + type="oauth2", + flows=[AuthFlowType.OAUTH2_PKCE], + authorization_url="https://accounts.google.com/o/oauth2/v2/auth", + token_url="https://oauth2.googleapis.com/token", + scopes=["profile", "email"], + supports_refresh=True, + supports_pkce=True, + ) + + assert provider.name == "google" + assert provider.type == "oauth2" + assert AuthFlowType.OAUTH2_PKCE in provider.flows + assert provider.supports_refresh is True + assert "email" in provider.scopes + + def test_provider_to_dict(self): + """Test serializing provider to dictionary.""" + provider = AuthProviderMetadata( + name="github", + type="oauth2", + flows=[AuthFlowType.OAUTH2_PKCE], + authorization_url="https://github.com/login/oauth/authorize", + token_url="https://github.com/login/oauth/access_token", + scopes=["repo", "user"], + ) + + data = provider.to_dict() + + assert data["name"] == "github" + assert data["type"] == "oauth2" + assert "oauth2-pkce" in data["flows"] + assert data["scopes"] == ["repo", "user"] + + def test_api_key_provider(self): + """Test creating API key provider metadata.""" + provider = AuthProviderMetadata( + name="custom-api", + type="api-key", + flows=[AuthFlowType.API_KEY], + ) + + assert provider.name == "custom-api" + assert AuthFlowType.API_KEY in provider.flows + assert provider.supports_pkce is True # Default + + +class TestAuthMetadata: + """Test AuthMetadata functionality.""" + + def test_create_empty_metadata(self): + """Test creating empty auth metadata.""" + metadata = AuthMetadata() + + assert metadata.requirement == AuthRequirement.NONE + assert len(metadata.providers) == 0 + assert len(metadata.required_scopes) == 0 + + def test_create_required_auth(self): + """Test creating metadata with required auth.""" + metadata = AuthMetadata( + requirement=AuthRequirement.REQUIRED, + required_scopes=["profile", "email"], + ) + + assert metadata.requirement == AuthRequirement.REQUIRED + assert "profile" in metadata.required_scopes + assert "email" in metadata.required_scopes + + def test_add_provider(self): + """Test adding a provider.""" + metadata = AuthMetadata() + metadata.add_provider( + name="google", + type="oauth2", + flows=[AuthFlowType.OAUTH2_PKCE], + authorization_url="https://accounts.google.com/o/oauth2/v2/auth", + token_url="https://oauth2.googleapis.com/token", + scopes=["profile", "email"], + supports_refresh=True, + ) + + assert len(metadata.providers) == 1 + assert metadata.providers[0].name == "google" + assert metadata.providers[0].supports_refresh is True + + def test_add_scopes(self): + """Test adding scopes.""" + metadata = AuthMetadata() + + metadata.add_required_scope("profile") + metadata.add_required_scope("email") + metadata.add_optional_scope("drive.readonly") + + assert "profile" in metadata.required_scopes + assert "email" in metadata.required_scopes + assert "drive.readonly" in metadata.optional_scopes + + def test_add_permissions(self): + """Test adding permissions.""" + metadata = AuthMetadata() + + metadata.add_permission("file.read") + metadata.add_permission("file.write") + + assert "file.read" in metadata.permissions + assert "file.write" in metadata.permissions + + def test_add_roles(self): + """Test adding roles.""" + metadata = AuthMetadata() + + metadata.add_role("admin") + metadata.add_role("user") + + assert "admin" in metadata.roles + assert "user" in metadata.roles + + def test_to_dict(self): + """Test serializing to dictionary.""" + metadata = AuthMetadata( + requirement=AuthRequirement.REQUIRED, + required_scopes=["profile"], + supports_multi_user=True, + ) + metadata.add_provider( + name="github", + type="oauth2", + flows=[AuthFlowType.OAUTH2_PKCE], + ) + + data = metadata.to_dict() + + assert data["requirement"] == "required" + assert "profile" in data["required_scopes"] + assert data["supports_multi_user"] is True + assert len(data["providers"]) == 1 + assert data["providers"][0]["name"] == "github" + + def test_from_dict(self): + """Test deserializing from dictionary.""" + data = { + "requirement": "required", + "providers": [ + { + "name": "google", + "type": "oauth2", + "flows": ["oauth2-pkce"], + "authorization_url": "https://accounts.google.com/o/oauth2/v2/auth", + "token_url": "https://oauth2.googleapis.com/token", + "scopes": ["profile", "email"], + "supports_refresh": True, + "supports_pkce": True, + } + ], + "required_scopes": ["profile"], + "optional_scopes": ["email"], + "permissions": ["file.read"], + "supports_multi_user": True, + } + + metadata = AuthMetadata.from_dict(data) + + assert metadata.requirement == AuthRequirement.REQUIRED + assert len(metadata.providers) == 1 + assert metadata.providers[0].name == "google" + assert "profile" in metadata.required_scopes + assert "email" in metadata.optional_scopes + assert "file.read" in metadata.permissions + assert metadata.supports_multi_user is True + + def test_roundtrip_serialization(self): + """Test serialization roundtrip.""" + original = AuthMetadata( + requirement=AuthRequirement.REQUIRED, + required_scopes=["profile", "email"], + permissions=["file.read", "file.write"], + supports_multi_user=True, + ) + original.add_provider( + name="github", + type="oauth2", + flows=[AuthFlowType.OAUTH2_PKCE], + scopes=["repo", "user"], + ) + + # Serialize and deserialize + data = original.to_dict() + restored = AuthMetadata.from_dict(data) + + assert restored.requirement == original.requirement + assert restored.required_scopes == original.required_scopes + assert restored.permissions == original.permissions + assert restored.supports_multi_user == original.supports_multi_user + assert len(restored.providers) == len(original.providers) + assert restored.providers[0].name == original.providers[0].name + + def test_validate_valid_metadata(self): + """Test validating valid metadata.""" + metadata = AuthMetadata(requirement=AuthRequirement.OPTIONAL) + metadata.add_provider( + name="google", + type="oauth2", + flows=[AuthFlowType.OAUTH2_PKCE], + authorization_url="https://accounts.google.com/o/oauth2/v2/auth", + token_url="https://oauth2.googleapis.com/token", + ) + + is_valid, errors = metadata.validate() + + assert is_valid is True + assert len(errors) == 0 + + def test_validate_required_auth_without_providers(self): + """Test validation fails when auth required but no providers.""" + metadata = AuthMetadata(requirement=AuthRequirement.REQUIRED) + + is_valid, errors = metadata.validate() + + assert is_valid is False + assert any("no providers" in error.lower() for error in errors) + + def test_validate_oauth_without_urls(self): + """Test validation fails when OAuth provider missing URLs.""" + metadata = AuthMetadata() + metadata.add_provider( + name="google", + type="oauth2", + flows=[AuthFlowType.OAUTH2_PKCE], + # Missing authorization_url and token_url + ) + + is_valid, errors = metadata.validate() + + assert is_valid is False + assert any("authorization_url" in error for error in errors) + assert any("token_url" in error for error in errors) + + def test_validate_scope_conflict(self): + """Test validation fails when scope is both required and optional.""" + metadata = AuthMetadata() + metadata.add_required_scope("profile") + metadata.add_optional_scope("profile") # Conflict! + + is_valid, errors = metadata.validate() + + assert is_valid is False + assert any("both required and optional" in error.lower() for error in errors) + + def test_schema_structure(self): + """Test that JSON schema has expected structure.""" + assert "$schema" in AUTH_METADATA_SCHEMA + assert "properties" in AUTH_METADATA_SCHEMA + assert "requirement" in AUTH_METADATA_SCHEMA["properties"] + assert "providers" in AUTH_METADATA_SCHEMA["properties"] + + def test_multi_provider_metadata(self): + """Test metadata with multiple providers.""" + metadata = AuthMetadata(requirement=AuthRequirement.REQUIRED) + + metadata.add_provider( + name="google", + type="oauth2", + flows=[AuthFlowType.OAUTH2_PKCE], + authorization_url="https://accounts.google.com/o/oauth2/v2/auth", + token_url="https://oauth2.googleapis.com/token", + ) + + metadata.add_provider( + name="github", + type="oauth2", + flows=[AuthFlowType.OAUTH2_PKCE], + authorization_url="https://github.com/login/oauth/authorize", + token_url="https://github.com/login/oauth/access_token", + ) + + assert len(metadata.providers) == 2 + provider_names = [p.name for p in metadata.providers] + assert "google" in provider_names + assert "github" in provider_names + + def test_token_refresh_configuration(self): + """Test token refresh configuration.""" + metadata = AuthMetadata( + token_refresh_enabled=True, + ) + + assert metadata.token_refresh_enabled is True + + data = metadata.to_dict() + assert data["token_refresh_enabled"] is True + + def test_session_management_types(self): + """Test different session management types.""" + for session_type in ["server-side", "client-side", "stateless"]: + metadata = AuthMetadata(session_management=session_type) + assert metadata.session_management == session_type + + def test_error_codes_documentation(self): + """Test error code documentation.""" + metadata = AuthMetadata() + metadata.error_codes = { + "auth_required": "Authentication is required to access this resource", + "insufficient_scopes": "Your token lacks required OAuth scopes", + } + + data = metadata.to_dict() + assert "auth_required" in data["error_codes"] + assert "insufficient_scopes" in data["error_codes"] diff --git a/tests/test_manifest.py b/tests/test_manifest.py new file mode 100644 index 0000000..9e3410c --- /dev/null +++ b/tests/test_manifest.py @@ -0,0 +1,524 @@ +""" +Tests for Permission Manifest system. + +Tests for declarative security definitions using manifests. +""" + +import tempfile +from pathlib import Path + +import pytest + +from nextmcp.auth.core import AuthContext, Permission, Role +from nextmcp.auth.manifest import PermissionManifest, ScopeDefinition, ToolPermission + + +class TestScopeDefinition: + """Tests for ScopeDefinition dataclass.""" + + def test_scope_definition_creation(self): + """Test creating a ScopeDefinition.""" + scope = ScopeDefinition( + name="read:data", + description="Read access to data", + oauth_mapping={"github": ["repo:read"], "google": ["drive.readonly"]}, + ) + + assert scope.name == "read:data" + assert scope.description == "Read access to data" + assert scope.oauth_mapping["github"] == ["repo:read"] + assert scope.oauth_mapping["google"] == ["drive.readonly"] + + def test_scope_definition_minimal(self): + """Test creating ScopeDefinition with minimal fields.""" + scope = ScopeDefinition(name="write:data", description="Write data") + + assert scope.name == "write:data" + assert scope.description == "Write data" + assert scope.oauth_mapping == {} + + def test_scope_definition_no_description(self): + """Test ScopeDefinition with empty description.""" + scope = ScopeDefinition(name="delete:data", description="") + + assert scope.name == "delete:data" + assert scope.description == "" + + +class TestToolPermission: + """Tests for ToolPermission dataclass.""" + + def test_tool_permission_creation(self): + """Test creating a ToolPermission.""" + tool = ToolPermission( + tool_name="query_database", + permissions=["read:data"], + scopes=["db.query.read"], + roles=["viewer", "editor"], + description="Execute database queries", + dangerous=False, + ) + + assert tool.tool_name == "query_database" + assert tool.permissions == ["read:data"] + assert tool.scopes == ["db.query.read"] + assert tool.roles == ["viewer", "editor"] + assert tool.description == "Execute database queries" + assert tool.dangerous is False + + def test_tool_permission_minimal(self): + """Test ToolPermission with minimal fields.""" + tool = ToolPermission(tool_name="simple_tool") + + assert tool.tool_name == "simple_tool" + assert tool.permissions == [] + assert tool.scopes == [] + assert tool.roles == [] + assert tool.description == "" + assert tool.dangerous is False + + def test_tool_permission_dangerous_flag(self): + """Test ToolPermission with dangerous flag.""" + tool = ToolPermission(tool_name="delete_all", dangerous=True) + + assert tool.dangerous is True + + +class TestPermissionManifest: + """Tests for PermissionManifest class.""" + + def test_manifest_initialization(self): + """Test creating an empty PermissionManifest.""" + manifest = PermissionManifest() + + assert manifest.scopes == {} + assert manifest.tools == {} + + def test_define_scope(self): + """Test defining a scope in the manifest.""" + manifest = PermissionManifest() + + scope = manifest.define_scope( + name="read:data", description="Read data", oauth_mapping={"github": ["repo:read"]} + ) + + assert isinstance(scope, ScopeDefinition) + assert scope.name == "read:data" + assert "read:data" in manifest.scopes + assert manifest.scopes["read:data"] == scope + + def test_define_multiple_scopes(self): + """Test defining multiple scopes.""" + manifest = PermissionManifest() + + manifest.define_scope("read:data", "Read data") + manifest.define_scope("write:data", "Write data") + manifest.define_scope("delete:data", "Delete data") + + assert len(manifest.scopes) == 3 + assert "read:data" in manifest.scopes + assert "write:data" in manifest.scopes + assert "delete:data" in manifest.scopes + + def test_define_tool_permission(self): + """Test defining a tool permission.""" + manifest = PermissionManifest() + + tool = manifest.define_tool_permission( + tool_name="query_db", + permissions=["read:data"], + scopes=["db.query.read"], + roles=["viewer"], + ) + + assert isinstance(tool, ToolPermission) + assert tool.tool_name == "query_db" + assert "query_db" in manifest.tools + assert manifest.tools["query_db"] == tool + + def test_define_multiple_tool_permissions(self): + """Test defining multiple tool permissions.""" + manifest = PermissionManifest() + + manifest.define_tool_permission("tool1", permissions=["read"]) + manifest.define_tool_permission("tool2", scopes=["scope1"]) + manifest.define_tool_permission("tool3", roles=["admin"]) + + assert len(manifest.tools) == 3 + assert "tool1" in manifest.tools + assert "tool2" in manifest.tools + assert "tool3" in manifest.tools + + def test_load_from_dict(self): + """Test loading manifest from dictionary.""" + manifest = PermissionManifest() + + data = { + "scopes": [ + { + "name": "read:data", + "description": "Read data", + "oauth_mapping": {"github": ["repo:read"]}, + }, + {"name": "write:data", "description": "Write data"}, + ], + "tools": { + "query_db": { + "permissions": ["read:data"], + "scopes": ["db.query.read"], + "roles": ["viewer"], + "description": "Query database", + "dangerous": False, + }, + "delete_data": { + "permissions": ["delete:data"], + "scopes": ["db.delete"], + "roles": ["admin"], + "dangerous": True, + }, + }, + } + + manifest.load_from_dict(data) + + # Check scopes loaded + assert len(manifest.scopes) == 2 + assert "read:data" in manifest.scopes + assert "write:data" in manifest.scopes + assert manifest.scopes["read:data"].oauth_mapping["github"] == ["repo:read"] + + # Check tools loaded + assert len(manifest.tools) == 2 + assert "query_db" in manifest.tools + assert "delete_data" in manifest.tools + assert manifest.tools["query_db"].permissions == ["read:data"] + assert manifest.tools["delete_data"].dangerous is True + + def test_load_from_dict_empty(self): + """Test loading empty manifest.""" + manifest = PermissionManifest() + manifest.load_from_dict({}) + + assert len(manifest.scopes) == 0 + assert len(manifest.tools) == 0 + + def test_load_from_dict_scopes_only(self): + """Test loading manifest with only scopes.""" + manifest = PermissionManifest() + + data = {"scopes": [{"name": "read:data", "description": "Read"}]} + + manifest.load_from_dict(data) + + assert len(manifest.scopes) == 1 + assert len(manifest.tools) == 0 + + def test_load_from_dict_tools_only(self): + """Test loading manifest with only tools.""" + manifest = PermissionManifest() + + data = {"tools": {"tool1": {"permissions": ["read"]}}} + + manifest.load_from_dict(data) + + assert len(manifest.scopes) == 0 + assert len(manifest.tools) == 1 + + def test_to_dict(self): + """Test exporting manifest to dictionary.""" + manifest = PermissionManifest() + + manifest.define_scope("read:data", "Read", {"github": ["repo:read"]}) + manifest.define_tool_permission("query", permissions=["read:data"], dangerous=False) + + result = manifest.to_dict() + + assert "scopes" in result + assert "tools" in result + assert len(result["scopes"]) == 1 + assert len(result["tools"]) == 1 + assert result["scopes"][0]["name"] == "read:data" + assert "query" in result["tools"] + + def test_to_dict_empty(self): + """Test exporting empty manifest.""" + manifest = PermissionManifest() + result = manifest.to_dict() + + assert result == {"scopes": [], "tools": {}} + + +class TestManifestAccessControl: + """Tests for manifest-based access control.""" + + def test_check_tool_access_no_restrictions(self): + """Test tool access when tool is not in manifest.""" + manifest = PermissionManifest() + context = AuthContext(authenticated=True, user_id="user1") + + allowed, error = manifest.check_tool_access("unknown_tool", context) + + assert allowed is True + assert error is None + + def test_check_tool_access_with_role_success(self): + """Test tool access with required role - success.""" + manifest = PermissionManifest() + manifest.define_tool_permission("admin_tool", roles=["admin"]) + + context = AuthContext(authenticated=True, user_id="user1") + context.add_role(Role("admin")) + + allowed, error = manifest.check_tool_access("admin_tool", context) + + assert allowed is True + assert error is None + + def test_check_tool_access_with_role_failure(self): + """Test tool access with required role - missing role.""" + manifest = PermissionManifest() + manifest.define_tool_permission("admin_tool", roles=["admin"]) + + context = AuthContext(authenticated=True, user_id="user1") + context.add_role(Role("viewer")) + + allowed, error = manifest.check_tool_access("admin_tool", context) + + assert allowed is False + assert error is not None + assert "admin" in error + + def test_check_tool_access_with_permission_success(self): + """Test tool access with required permission - success.""" + manifest = PermissionManifest() + manifest.define_tool_permission("read_tool", permissions=["read:data"]) + + context = AuthContext(authenticated=True, user_id="user1") + context.add_permission(Permission("read:data")) + + allowed, error = manifest.check_tool_access("read_tool", context) + + assert allowed is True + assert error is None + + def test_check_tool_access_with_permission_failure(self): + """Test tool access with required permission - missing permission.""" + manifest = PermissionManifest() + manifest.define_tool_permission("write_tool", permissions=["write:data"]) + + context = AuthContext(authenticated=True, user_id="user1") + context.add_permission(Permission("read:data")) + + allowed, error = manifest.check_tool_access("write_tool", context) + + assert allowed is False + assert error is not None + assert "write:data" in error + + def test_check_tool_access_with_scope_success(self): + """Test tool access with required scope - success.""" + manifest = PermissionManifest() + manifest.define_tool_permission("oauth_tool", scopes=["repo:read"]) + + context = AuthContext(authenticated=True, user_id="user1") + context.add_scope("repo:read") + + allowed, error = manifest.check_tool_access("oauth_tool", context) + + assert allowed is True + assert error is None + + def test_check_tool_access_with_scope_failure(self): + """Test tool access with required scope - missing scope.""" + manifest = PermissionManifest() + manifest.define_tool_permission("oauth_tool", scopes=["repo:write"]) + + context = AuthContext(authenticated=True, user_id="user1") + context.add_scope("repo:read") + + allowed, error = manifest.check_tool_access("oauth_tool", context) + + assert allowed is False + assert error is not None + assert "repo:write" in error + + def test_check_tool_access_multiple_requirements_any_match(self): + """Test tool access with multiple requirements - any one matches.""" + manifest = PermissionManifest() + manifest.define_tool_permission( + "multi_tool", permissions=["read:data", "write:data", "admin:all"] + ) + + context = AuthContext(authenticated=True, user_id="user1") + context.add_permission(Permission("write:data")) # Has one of the required permissions + + allowed, error = manifest.check_tool_access("multi_tool", context) + + assert allowed is True + assert error is None + + def test_check_tool_access_combined_role_and_permission(self): + """Test tool access with both role and permission requirements.""" + manifest = PermissionManifest() + manifest.define_tool_permission("combined_tool", roles=["editor"], permissions=["edit:data"]) + + # Has role but not permission + context1 = AuthContext(authenticated=True, user_id="user1") + context1.add_role(Role("editor")) + + allowed, error = manifest.check_tool_access("combined_tool", context1) + # Should fail because missing permission + assert allowed is False + + # Has both role and permission + context2 = AuthContext(authenticated=True, user_id="user2") + context2.add_role(Role("editor")) + context2.add_permission(Permission("edit:data")) + + allowed, error = manifest.check_tool_access("combined_tool", context2) + assert allowed is True + + def test_check_tool_access_all_three_types(self): + """Test tool access requiring role, permission, AND scope.""" + manifest = PermissionManifest() + manifest.define_tool_permission( + "strict_tool", roles=["admin"], permissions=["admin:all"], scopes=["admin:full"] + ) + + # Missing all three + context1 = AuthContext(authenticated=True, user_id="user1") + allowed, _ = manifest.check_tool_access("strict_tool", context1) + assert allowed is False + + # Has role only + context2 = AuthContext(authenticated=True, user_id="user2") + context2.add_role(Role("admin")) + allowed, _ = manifest.check_tool_access("strict_tool", context2) + assert allowed is False # Still needs permission AND scope + + # Has all three + context3 = AuthContext(authenticated=True, user_id="user3") + context3.add_role(Role("admin")) + context3.add_permission(Permission("admin:all")) + context3.add_scope("admin:full") + allowed, _ = manifest.check_tool_access("strict_tool", context3) + assert allowed is True + + +class TestManifestYAMLLoading: + """Tests for loading manifests from YAML files.""" + + def test_load_from_yaml_file(self): + """Test loading manifest from YAML file.""" + manifest = PermissionManifest() + + yaml_content = """ +scopes: + - name: "read:data" + description: "Read access" + oauth_mapping: + github: ["repo:read"] + - name: "write:data" + description: "Write access" + +tools: + query_tool: + scopes: ["read:data"] + permissions: ["read:db"] + description: "Query database" + write_tool: + scopes: ["write:data"] + dangerous: true +""" + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write(yaml_content) + yaml_path = f.name + + try: + manifest.load_from_yaml(yaml_path) + + assert len(manifest.scopes) == 2 + assert "read:data" in manifest.scopes + assert "write:data" in manifest.scopes + + assert len(manifest.tools) == 2 + assert "query_tool" in manifest.tools + assert "write_tool" in manifest.tools + assert manifest.tools["write_tool"].dangerous is True + + finally: + Path(yaml_path).unlink() + + def test_load_from_yaml_invalid_file(self): + """Test loading from non-existent YAML file raises error.""" + manifest = PermissionManifest() + + with pytest.raises(FileNotFoundError): + manifest.load_from_yaml("/nonexistent/file.yaml") + + +class TestManifestEdgeCases: + """Tests for edge cases in manifest handling.""" + + def test_scope_with_special_characters(self): + """Test scopes with special characters.""" + manifest = PermissionManifest() + + manifest.define_scope("https://api.example.com/read", "URL scope") + manifest.define_scope("admin:*", "Wildcard scope") + + assert len(manifest.scopes) == 2 + + def test_tool_with_empty_requirements(self): + """Test tool with all empty requirement lists.""" + manifest = PermissionManifest() + manifest.define_tool_permission("open_tool", permissions=[], scopes=[], roles=[]) + + context = AuthContext(authenticated=True, user_id="user1") + allowed, error = manifest.check_tool_access("open_tool", context) + + # Should allow access since no requirements + assert allowed is True + assert error is None + + def test_manifest_overwrite_scope(self): + """Test that redefining a scope overwrites it.""" + manifest = PermissionManifest() + + manifest.define_scope("data:read", "First description") + manifest.define_scope("data:read", "Second description") + + assert len(manifest.scopes) == 1 + assert manifest.scopes["data:read"].description == "Second description" + + def test_manifest_overwrite_tool(self): + """Test that redefining a tool overwrites it.""" + manifest = PermissionManifest() + + manifest.define_tool_permission("tool1", permissions=["old"]) + manifest.define_tool_permission("tool1", permissions=["new"]) + + assert len(manifest.tools) == 1 + assert manifest.tools["tool1"].permissions == ["new"] + + def test_to_dict_round_trip(self): + """Test that to_dict and load_from_dict are inverses.""" + manifest1 = PermissionManifest() + + manifest1.define_scope("read", "Read", {"gh": ["repo:read"]}) + manifest1.define_tool_permission("tool1", permissions=["read"], dangerous=True) + + # Export to dict + data = manifest1.to_dict() + + # Load into new manifest + manifest2 = PermissionManifest() + manifest2.load_from_dict(data) + + # Should be equivalent + assert len(manifest2.scopes) == len(manifest1.scopes) + assert len(manifest2.tools) == len(manifest1.tools) + assert "read" in manifest2.scopes + assert "tool1" in manifest2.tools + assert manifest2.tools["tool1"].dangerous is True diff --git a/tests/test_manifest_middleware.py b/tests/test_manifest_middleware.py new file mode 100644 index 0000000..1a53545 --- /dev/null +++ b/tests/test_manifest_middleware.py @@ -0,0 +1,371 @@ +""" +Tests for manifest-middleware integration. + +Tests @requires_manifest decorator for enforcing PermissionManifest at runtime. +""" + +import pytest + +from nextmcp.auth.core import AuthContext, Permission, Role +from nextmcp.auth.errors import ManifestViolationError +from nextmcp.auth.manifest import PermissionManifest +from nextmcp.auth.middleware import requires_auth_async, requires_manifest_async +from nextmcp.auth.providers import APIKeyProvider + + +class MockAuthProvider(APIKeyProvider): + """Mock auth provider for testing.""" + + def __init__(self, **kwargs): + # Accept custom valid_keys or use default + if 'valid_keys' not in kwargs: + kwargs['valid_keys'] = {"test_key": {"user_id": "user123"}} + super().__init__(**kwargs) + + +@pytest.fixture +def auth_provider(): + """Create mock auth provider.""" + return MockAuthProvider() + + +@pytest.fixture +def simple_manifest(): + """Create a simple permission manifest for testing.""" + manifest = PermissionManifest() + + # Define a tool requiring admin role + manifest.define_tool_permission( + tool_name="admin_tool", + roles=["admin"], + ) + + # Define a tool requiring read permission + manifest.define_tool_permission( + tool_name="read_tool", + permissions=["read:data"], + ) + + # Define a tool requiring OAuth scope + manifest.define_tool_permission( + tool_name="oauth_tool", + scopes=["repo:read"], + ) + + # Define a tool with multiple requirements + manifest.define_tool_permission( + tool_name="strict_tool", + roles=["admin"], + permissions=["write:data"], + scopes=["admin:full"], + ) + + return manifest + + +class TestRequiresManifestAsync: + """Tests for @requires_manifest_async decorator.""" + + @pytest.mark.asyncio + async def test_manifest_allows_unrestricted_tool(self, auth_provider, simple_manifest): + """Test that tools not in manifest are allowed.""" + + @requires_auth_async(provider=auth_provider) + @requires_manifest_async(manifest=simple_manifest, tool_name="unrestricted_tool") + async def unrestricted_tool(auth: AuthContext): + return f"Success for {auth.user_id}" + + result = await unrestricted_tool(auth={"api_key": "test_key"}) + assert result == "Success for user123" + + @pytest.mark.asyncio + async def test_manifest_allows_with_required_role(self, auth_provider, simple_manifest): + """Test manifest allows access when user has required role.""" + + # Mock provider that adds admin role + class AdminAuthProvider(APIKeyProvider): + async def authenticate(self, credentials): + result = await super().authenticate(credentials) + if result.success: + result.context.add_role(Role("admin")) + return result + + admin_provider = AdminAuthProvider(valid_keys={"admin_key": {"user_id": "admin_user"}}) + + @requires_auth_async(provider=admin_provider) + @requires_manifest_async(manifest=simple_manifest, tool_name="admin_tool") + async def admin_tool(auth: AuthContext): + return "Admin action completed" + + result = await admin_tool(auth={"api_key": "admin_key"}) + assert result == "Admin action completed" + + @pytest.mark.asyncio + async def test_manifest_denies_without_required_role(self, auth_provider, simple_manifest): + """Test manifest denies access when user lacks required role.""" + + @requires_auth_async(provider=auth_provider) + @requires_manifest_async(manifest=simple_manifest, tool_name="admin_tool") + async def admin_tool(auth: AuthContext): + return "Should not reach here" + + with pytest.raises(ManifestViolationError) as exc_info: + await admin_tool(auth={"api_key": "test_key"}) + + assert exc_info.value.tool_name == "admin_tool" + assert "admin" in exc_info.value.required_roles + + @pytest.mark.asyncio + async def test_manifest_allows_with_required_permission(self, auth_provider, simple_manifest): + """Test manifest allows access when user has required permission.""" + + # Mock provider that adds read permission + class ReadAuthProvider(APIKeyProvider): + async def authenticate(self, credentials): + result = await super().authenticate(credentials) + if result.success: + result.context.add_permission(Permission("read:data")) + return result + + read_provider = ReadAuthProvider(valid_keys={"read_key": {"user_id": "read_user"}}) + + @requires_auth_async(provider=read_provider) + @requires_manifest_async(manifest=simple_manifest, tool_name="read_tool") + async def read_tool(auth: AuthContext): + return "Read completed" + + result = await read_tool(auth={"api_key": "read_key"}) + assert result == "Read completed" + + @pytest.mark.asyncio + async def test_manifest_denies_without_required_permission(self, auth_provider, simple_manifest): + """Test manifest denies access when user lacks required permission.""" + + @requires_auth_async(provider=auth_provider) + @requires_manifest_async(manifest=simple_manifest, tool_name="read_tool") + async def read_tool(auth: AuthContext): + return "Should not reach here" + + with pytest.raises(ManifestViolationError) as exc_info: + await read_tool(auth={"api_key": "test_key"}) + + assert exc_info.value.tool_name == "read_tool" + assert "read:data" in exc_info.value.required_permissions + + @pytest.mark.asyncio + async def test_manifest_allows_with_required_scope(self, auth_provider, simple_manifest): + """Test manifest allows access when user has required scope.""" + + # Mock provider that adds OAuth scope + class OAuthAuthProvider(APIKeyProvider): + async def authenticate(self, credentials): + result = await super().authenticate(credentials) + if result.success: + result.context.add_scope("repo:read") + return result + + oauth_provider = OAuthAuthProvider(valid_keys={"oauth_key": {"user_id": "oauth_user"}}) + + @requires_auth_async(provider=oauth_provider) + @requires_manifest_async(manifest=simple_manifest, tool_name="oauth_tool") + async def oauth_tool(auth: AuthContext): + return "OAuth action completed" + + result = await oauth_tool(auth={"api_key": "oauth_key"}) + assert result == "OAuth action completed" + + @pytest.mark.asyncio + async def test_manifest_denies_without_required_scope(self, auth_provider, simple_manifest): + """Test manifest denies access when user lacks required scope.""" + + @requires_auth_async(provider=auth_provider) + @requires_manifest_async(manifest=simple_manifest, tool_name="oauth_tool") + async def oauth_tool(auth: AuthContext): + return "Should not reach here" + + with pytest.raises(ManifestViolationError) as exc_info: + await oauth_tool(auth={"api_key": "test_key"}) + + assert exc_info.value.tool_name == "oauth_tool" + assert "repo:read" in exc_info.value.required_scopes + + @pytest.mark.asyncio + async def test_manifest_requires_all_requirement_types(self, auth_provider, simple_manifest): + """Test manifest requires ALL types (role AND permission AND scope).""" + + # Provider with role only + class RoleOnlyProvider(APIKeyProvider): + async def authenticate(self, credentials): + result = await super().authenticate(credentials) + if result.success: + result.context.add_role(Role("admin")) + return result + + role_provider = RoleOnlyProvider(valid_keys={"key": {"user_id": "user"}}) + + @requires_auth_async(provider=role_provider) + @requires_manifest_async(manifest=simple_manifest, tool_name="strict_tool") + async def strict_tool(auth: AuthContext): + return "Should not reach here" + + # Should fail because missing permission and scope + with pytest.raises(ManifestViolationError): + await strict_tool(auth={"api_key": "key"}) + + @pytest.mark.asyncio + async def test_manifest_allows_with_all_requirements(self, auth_provider, simple_manifest): + """Test manifest allows when user has ALL requirements.""" + + # Provider with all requirements + class FullAuthProvider(APIKeyProvider): + async def authenticate(self, credentials): + result = await super().authenticate(credentials) + if result.success: + result.context.add_role(Role("admin")) + result.context.add_permission(Permission("write:data")) + result.context.add_scope("admin:full") + return result + + full_provider = FullAuthProvider(valid_keys={"full_key": {"user_id": "full_user"}}) + + @requires_auth_async(provider=full_provider) + @requires_manifest_async(manifest=simple_manifest, tool_name="strict_tool") + async def strict_tool(auth: AuthContext): + return "All requirements met" + + result = await strict_tool(auth={"api_key": "full_key"}) + assert result == "All requirements met" + + @pytest.mark.asyncio + async def test_manifest_error_contains_user_context(self, auth_provider, simple_manifest): + """Test ManifestViolationError contains user context.""" + + @requires_auth_async(provider=auth_provider) + @requires_manifest_async(manifest=simple_manifest, tool_name="admin_tool") + async def admin_tool(auth: AuthContext): + return "Should not reach here" + + with pytest.raises(ManifestViolationError) as exc_info: + await admin_tool(auth={"api_key": "test_key"}) + + error = exc_info.value + assert error.user_id == "user123" + assert error.auth_context is not None + assert error.auth_context.user_id == "user123" + + @pytest.mark.asyncio + async def test_manifest_decorator_without_tool_name(self, auth_provider, simple_manifest): + """Test decorator can infer tool name from function name.""" + + @requires_auth_async(provider=auth_provider) + @requires_manifest_async(manifest=simple_manifest) # No tool_name specified + async def read_tool(auth: AuthContext): + return "Should not reach here" + + # Should use function name "read_tool" and check manifest + with pytest.raises(ManifestViolationError) as exc_info: + await read_tool(auth={"api_key": "test_key"}) + + assert exc_info.value.tool_name == "read_tool" + + @pytest.mark.asyncio + async def test_manifest_with_empty_manifest(self, auth_provider): + """Test decorator with empty manifest (no restrictions).""" + empty_manifest = PermissionManifest() + + @requires_auth_async(provider=auth_provider) + @requires_manifest_async(manifest=empty_manifest, tool_name="any_tool") + async def any_tool(auth: AuthContext): + return "Allowed" + + result = await any_tool(auth={"api_key": "test_key"}) + assert result == "Allowed" + + +class TestManifestIntegrationPatterns: + """Tests for common manifest integration patterns.""" + + @pytest.mark.asyncio + async def test_multiple_tools_with_shared_manifest(self, auth_provider): + """Test multiple tools sharing a manifest.""" + manifest = PermissionManifest() + manifest.define_tool_permission("tool1", roles=["editor"]) + manifest.define_tool_permission("tool2", roles=["viewer", "editor"]) + + # Provider that adds editor role + class EditorProvider(APIKeyProvider): + async def authenticate(self, credentials): + result = await super().authenticate(credentials) + if result.success: + result.context.add_role(Role("editor")) + return result + + editor_provider = EditorProvider(valid_keys={"key": {"user_id": "user"}}) + + @requires_auth_async(provider=editor_provider) + @requires_manifest_async(manifest=manifest, tool_name="tool1") + async def tool1(auth: AuthContext): + return "Tool 1" + + @requires_auth_async(provider=editor_provider) + @requires_manifest_async(manifest=manifest, tool_name="tool2") + async def tool2(auth: AuthContext): + return "Tool 2" + + # Both should succeed + assert await tool1(auth={"api_key": "key"}) == "Tool 1" + assert await tool2(auth={"api_key": "key"}) == "Tool 2" + + @pytest.mark.asyncio + async def test_manifest_with_wildcard_permission(self, auth_provider): + """Test manifest with wildcard permission matching.""" + manifest = PermissionManifest() + manifest.define_tool_permission("data_tool", permissions=["data:read"]) + + # Provider that adds wildcard permission + class WildcardProvider(APIKeyProvider): + async def authenticate(self, credentials): + result = await super().authenticate(credentials) + if result.success: + result.context.add_permission(Permission("data:*")) + return result + + wildcard_provider = WildcardProvider(valid_keys={"key": {"user_id": "user"}}) + + @requires_auth_async(provider=wildcard_provider) + @requires_manifest_async(manifest=manifest, tool_name="data_tool") + async def data_tool(auth: AuthContext): + return "Data access granted" + + result = await data_tool(auth={"api_key": "key"}) + assert result == "Data access granted" + + @pytest.mark.asyncio + async def test_manifest_loaded_from_dict(self, auth_provider): + """Test manifest loaded from dictionary configuration.""" + manifest = PermissionManifest() + manifest.load_from_dict({ + "tools": { + "query_db": { + "roles": ["analyst"], + } + } + }) + + # Provider with analyst role + class AnalystProvider(APIKeyProvider): + async def authenticate(self, credentials): + result = await super().authenticate(credentials) + if result.success: + result.context.add_role(Role("analyst")) + return result + + analyst_provider = AnalystProvider(valid_keys={"key": {"user_id": "user"}}) + + @requires_auth_async(provider=analyst_provider) + @requires_manifest_async(manifest=manifest, tool_name="query_db") + async def query_db(auth: AuthContext): + return "Query executed" + + result = await query_db(auth={"api_key": "key"}) + assert result == "Query executed" diff --git a/tests/test_oauth.py b/tests/test_oauth.py new file mode 100644 index 0000000..25e9094 --- /dev/null +++ b/tests/test_oauth.py @@ -0,0 +1,611 @@ +""" +Tests for OAuth 2.0 authentication providers. + +Tests for PKCE, OAuth base provider, and specific OAuth providers (GitHub, Google). +""" + +import base64 +import hashlib +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from nextmcp.auth.oauth import OAuthConfig, OAuthProvider, PKCEChallenge +from nextmcp.auth.oauth_providers import GitHubOAuthProvider, GoogleOAuthProvider + + +def create_mock_aiohttp_response(status: int, json_data: dict | None = None, text_data: str = ""): + """Helper to create a properly mocked aiohttp response.""" + mock_response = AsyncMock() + mock_response.status = status + + # Set up headers - default to JSON if json_data is provided + headers = {} + if json_data is not None: + headers["Content-Type"] = "application/json" + mock_response.json = AsyncMock(return_value=json_data) + if text_data: + headers["Content-Type"] = "application/x-www-form-urlencoded" + mock_response.text = AsyncMock(return_value=text_data) + + mock_response.headers = headers + mock_response.__aenter__.return_value = mock_response + # __aexit__ must return False/None to not suppress exceptions + mock_response.__aexit__ = AsyncMock(return_value=False) + return mock_response + + +def create_mock_aiohttp_session(**responses): + """ + Helper to create a properly mocked aiohttp ClientSession. + + Args: + **responses: Keyword arguments where key is method name ('get', 'post') + and value is the mock response + """ + mock_session = MagicMock() + + for method, response in responses.items(): + # Create a method that returns the response (which is already an async context manager) + method_mock = MagicMock(return_value=response) + setattr(mock_session, method, method_mock) + + return mock_session + + +class TestPKCEChallenge: + """Tests for PKCE challenge generation.""" + + def test_pkce_generation(self): + """Test PKCE challenge is generated correctly.""" + challenge = PKCEChallenge.generate() + + assert isinstance(challenge.verifier, str) + assert isinstance(challenge.challenge, str) + assert challenge.method == "S256" + + # Verifier should be 43+ characters (base64url encoded 32 bytes) + assert len(challenge.verifier) >= 43 + + # Challenge should be 43+ characters (base64url encoded SHA256 hash) + assert len(challenge.challenge) >= 43 + + def test_pkce_verifier_uniqueness(self): + """Test that each PKCE generation produces unique verifiers.""" + challenge1 = PKCEChallenge.generate() + challenge2 = PKCEChallenge.generate() + + assert challenge1.verifier != challenge2.verifier + assert challenge1.challenge != challenge2.challenge + + def test_pkce_challenge_derivation(self): + """Test that challenge is correctly derived from verifier.""" + challenge = PKCEChallenge.generate() + + # Manually compute challenge from verifier + expected_challenge = base64.urlsafe_b64encode( + hashlib.sha256(challenge.verifier.encode("utf-8")).digest() + ).decode("utf-8").rstrip("=") + + assert challenge.challenge == expected_challenge + + def test_pkce_no_padding(self): + """Test that PKCE values don't contain base64 padding.""" + challenge = PKCEChallenge.generate() + + # Base64url encoding should not have padding (=) + assert "=" not in challenge.verifier + assert "=" not in challenge.challenge + + +class TestOAuthConfig: + """Tests for OAuth configuration.""" + + def test_oauth_config_creation(self): + """Test OAuth configuration creation.""" + config = OAuthConfig( + client_id="test_client_id", + client_secret="test_secret", + authorization_url="https://provider.com/oauth/authorize", + token_url="https://provider.com/oauth/token", + redirect_uri="http://localhost:8080/callback", + scope=["read", "write"], + ) + + assert config.client_id == "test_client_id" + assert config.client_secret == "test_secret" + assert config.authorization_url == "https://provider.com/oauth/authorize" + assert config.token_url == "https://provider.com/oauth/token" + assert config.redirect_uri == "http://localhost:8080/callback" + assert config.scope == ["read", "write"] + + def test_oauth_config_optional_secret(self): + """Test OAuth config with optional client secret (for PKCE).""" + config = OAuthConfig( + client_id="test_client_id", + authorization_url="https://provider.com/oauth/authorize", + token_url="https://provider.com/oauth/token", + ) + + assert config.client_secret is None + + +class MockOAuthProvider(OAuthProvider): + """Mock OAuth provider for testing base class.""" + + async def get_user_info(self, access_token: str): + """Mock user info retrieval.""" + return { + "id": "12345", + "login": "testuser", + "email": "test@example.com", + } + + def get_additional_auth_params(self): + """Mock additional auth params.""" + return {"extra_param": "value"} + + def extract_user_id(self, user_info): + """Extract user ID from user info.""" + return str(user_info["id"]) + + +class TestOAuthProvider: + """Tests for OAuth base provider.""" + + def test_provider_initialization(self): + """Test OAuth provider initialization.""" + config = OAuthConfig( + client_id="test_client", + client_secret="test_secret", + authorization_url="https://provider.com/oauth/authorize", + token_url="https://provider.com/oauth/token", + ) + + provider = MockOAuthProvider(config) + + assert provider.config == config + assert provider._pending_auth == {} + + def test_generate_authorization_url(self): + """Test OAuth authorization URL generation.""" + config = OAuthConfig( + client_id="test_client", + authorization_url="https://provider.com/oauth/authorize", + token_url="https://provider.com/oauth/token", + redirect_uri="http://localhost:8080/callback", + scope=["read", "write"], + ) + + provider = MockOAuthProvider(config) + auth_data = provider.generate_authorization_url() + + # Check returned data + assert "url" in auth_data + assert "state" in auth_data + assert "verifier" in auth_data + + # Check URL contains required parameters + url = auth_data["url"] + assert "https://provider.com/oauth/authorize" in url + assert "client_id=test_client" in url + # URL encoded redirect_uri + assert ("redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fcallback" in url or + "redirect_uri=http://localhost:8080/callback" in url) + assert "response_type=code" in url + assert f"state={auth_data['state']}" in url + assert "code_challenge=" in url + assert "code_challenge_method=S256" in url + # URL encoded scopes (+ or %20 for spaces) + assert "scope=read+write" in url or "scope=read%20write" in url + + # Check PKCE is stored + assert auth_data["state"] in provider._pending_auth + + def test_generate_authorization_url_custom_state(self): + """Test authorization URL generation with custom state.""" + config = OAuthConfig( + client_id="test_client", + authorization_url="https://provider.com/oauth/authorize", + token_url="https://provider.com/oauth/token", + ) + + provider = MockOAuthProvider(config) + custom_state = "my_custom_state_123" + auth_data = provider.generate_authorization_url(state=custom_state) + + assert auth_data["state"] == custom_state + assert custom_state in provider._pending_auth + + @pytest.mark.asyncio + async def test_exchange_code_for_token(self): + """Test exchanging authorization code for access token.""" + config = OAuthConfig( + client_id="test_client", + client_secret="test_secret", + token_url="https://provider.com/oauth/token", + redirect_uri="http://localhost:8080/callback", + ) + + provider = MockOAuthProvider(config) + + # Generate auth URL to create PKCE + auth_data = provider.generate_authorization_url() + state = auth_data["state"] + verifier = auth_data["verifier"] + + # Mock the HTTP response + mock_response = { + "access_token": "mock_access_token", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "mock_refresh_token", + "scope": "read write", + } + + with patch("aiohttp.ClientSession") as MockSession: + mock_resp = create_mock_aiohttp_response(200, json_data=mock_response) + mock_session_inst = create_mock_aiohttp_session(post=mock_resp) + MockSession.return_value.__aenter__.return_value = mock_session_inst + MockSession.return_value.__aexit__.return_value = AsyncMock() + + # Exchange code for token + token_data = await provider.exchange_code_for_token( + code="auth_code_123", state=state + ) + + assert token_data["access_token"] == "mock_access_token" + assert token_data["refresh_token"] == "mock_refresh_token" + + # PKCE should be consumed + assert state not in provider._pending_auth + + @pytest.mark.asyncio + async def test_exchange_code_with_external_verifier(self): + """Test token exchange with externally stored verifier.""" + config = OAuthConfig( + client_id="test_client", + token_url="https://provider.com/oauth/token", + redirect_uri="http://localhost:8080/callback", + ) + + provider = MockOAuthProvider(config) + + # Don't use provider's generate_authorization_url + # Instead, provide verifier manually + external_verifier = PKCEChallenge.generate().verifier + + mock_response = { + "access_token": "mock_access_token", + "token_type": "Bearer", + } + + with patch("aiohttp.ClientSession") as MockSession: + mock_resp = create_mock_aiohttp_response(200, json_data=mock_response) + mock_session_inst = create_mock_aiohttp_session(post=mock_resp) + MockSession.return_value.__aenter__.return_value = mock_session_inst + MockSession.return_value.__aexit__.return_value = AsyncMock() + + # Exchange with external verifier + token_data = await provider.exchange_code_for_token( + code="auth_code_123", + state="external_state", + verifier=external_verifier, + ) + + assert token_data["access_token"] == "mock_access_token" + + @pytest.mark.asyncio + async def test_exchange_code_invalid_state(self): + """Test token exchange with invalid state raises error.""" + config = OAuthConfig( + client_id="test_client", + token_url="https://provider.com/oauth/token", + ) + + provider = MockOAuthProvider(config) + + # Try to exchange without generating auth URL first + with pytest.raises(ValueError, match="Invalid state or expired authorization"): + await provider.exchange_code_for_token( + code="auth_code_123", state="invalid_state" + ) + + @pytest.mark.asyncio + async def test_exchange_code_token_error(self): + """Test token exchange handles error responses.""" + config = OAuthConfig( + client_id="test_client", + token_url="https://provider.com/oauth/token", + ) + + provider = MockOAuthProvider(config) + auth_data = provider.generate_authorization_url() + + # Mock error response + mock_error = {"error": "invalid_grant", "error_description": "Code expired"} + + with patch("aiohttp.ClientSession") as MockSession: + mock_resp = create_mock_aiohttp_response(400, json_data=mock_error) + mock_session_inst = create_mock_aiohttp_session(post=mock_resp) + MockSession.return_value.__aenter__.return_value = mock_session_inst + MockSession.return_value.__aexit__ = AsyncMock(return_value=False) + + with pytest.raises(ValueError, match="Token exchange failed"): + await provider.exchange_code_for_token( + code="invalid_code", state=auth_data["state"] + ) + + @pytest.mark.asyncio + async def test_refresh_access_token(self): + """Test refreshing access token.""" + config = OAuthConfig( + client_id="test_client", + client_secret="test_secret", + token_url="https://provider.com/oauth/token", + ) + + provider = MockOAuthProvider(config) + + mock_response = { + "access_token": "new_access_token", + "token_type": "Bearer", + "expires_in": 3600, + } + + with patch("aiohttp.ClientSession") as MockSession: + mock_resp = create_mock_aiohttp_response(200, json_data=mock_response) + mock_session_inst = create_mock_aiohttp_session(post=mock_resp) + MockSession.return_value.__aenter__.return_value = mock_session_inst + MockSession.return_value.__aexit__.return_value = AsyncMock() + + token_data = await provider.refresh_access_token("old_refresh_token") + + assert token_data["access_token"] == "new_access_token" + + @pytest.mark.asyncio + async def test_refresh_token_error(self): + """Test refresh token handles error responses.""" + config = OAuthConfig( + client_id="test_client", + token_url="https://provider.com/oauth/token", + ) + + provider = MockOAuthProvider(config) + + mock_error = {"error": "invalid_grant", "error_description": "Refresh token expired"} + + with patch("aiohttp.ClientSession") as MockSession: + mock_resp = create_mock_aiohttp_response(400, json_data=mock_error) + mock_session_inst = create_mock_aiohttp_session(post=mock_resp) + MockSession.return_value.__aenter__.return_value = mock_session_inst + MockSession.return_value.__aexit__ = AsyncMock(return_value=False) + + with pytest.raises(ValueError, match="Token refresh failed"): + await provider.refresh_access_token("invalid_refresh_token") + + @pytest.mark.asyncio + async def test_authenticate_with_access_token(self): + """Test authentication using OAuth access token.""" + config = OAuthConfig( + client_id="test_client", + token_url="https://provider.com/oauth/token", + ) + + provider = MockOAuthProvider(config) + + credentials = { + "access_token": "valid_access_token", + "refresh_token": "valid_refresh_token", + "scopes": ["read", "write"], + } + + result = await provider.authenticate(credentials) + + assert result.success is True + assert result.context is not None + assert result.context.authenticated is True + assert result.context.user_id == "12345" + assert result.context.username == "testuser" + + # OAuth provider should add scopes as permissions + assert result.context.has_permission("read") + assert result.context.has_permission("write") + + # Metadata should contain OAuth info + assert result.context.metadata["oauth_provider"] == "MockOAuthProvider" + assert result.context.metadata["access_token"] == "valid_access_token" + assert result.context.metadata["refresh_token"] == "valid_refresh_token" + + @pytest.mark.asyncio + async def test_authenticate_missing_access_token(self): + """Test authentication fails without access token.""" + config = OAuthConfig( + client_id="test_client", + token_url="https://provider.com/oauth/token", + ) + + provider = MockOAuthProvider(config) + + result = await provider.authenticate({}) + + assert result.success is False + assert result.error == "Missing access_token" + + @pytest.mark.asyncio + async def test_authenticate_user_info_error(self): + """Test authentication fails when user info retrieval fails.""" + config = OAuthConfig( + client_id="test_client", + token_url="https://provider.com/oauth/token", + ) + + # Create provider that raises error on get_user_info + class FailingOAuthProvider(MockOAuthProvider): + async def get_user_info(self, access_token): + raise Exception("User info API error") + + provider = FailingOAuthProvider(config) + + result = await provider.authenticate({"access_token": "token"}) + + assert result.success is False + assert "OAuth authentication failed" in result.error + + +class TestGitHubOAuthProvider: + """Tests for GitHub OAuth provider.""" + + def test_github_provider_initialization(self): + """Test GitHub provider initialization with default config.""" + provider = GitHubOAuthProvider( + client_id="github_client_id", + client_secret="github_secret", + ) + + assert provider.config.client_id == "github_client_id" + assert provider.config.client_secret == "github_secret" + assert provider.config.authorization_url == "https://github.com/login/oauth/authorize" + assert provider.config.token_url == "https://github.com/login/oauth/access_token" + assert provider.config.scope == ["read:user"] + + def test_github_provider_custom_scope(self): + """Test GitHub provider with custom scopes.""" + provider = GitHubOAuthProvider( + client_id="github_client_id", scope=["repo", "user:email"] + ) + + assert provider.config.scope == ["repo", "user:email"] + + @pytest.mark.asyncio + async def test_github_get_user_info(self): + """Test GitHub user info retrieval.""" + provider = GitHubOAuthProvider(client_id="test_client") + + mock_user_data = { + "id": 12345, + "login": "octocat", + "email": "octocat@github.com", + "name": "The Octocat", + } + + with patch("aiohttp.ClientSession") as MockSession: + mock_resp = create_mock_aiohttp_response(200, json_data=mock_user_data, text_data="success") + mock_session_inst = create_mock_aiohttp_session(get=mock_resp) + MockSession.return_value.__aenter__.return_value = mock_session_inst + MockSession.return_value.__aexit__.return_value = AsyncMock() + + user_info = await provider.get_user_info("test_access_token") + + assert user_info["id"] == 12345 + assert user_info["login"] == "octocat" + + @pytest.mark.asyncio + async def test_github_get_user_info_error(self): + """Test GitHub user info retrieval with error.""" + provider = GitHubOAuthProvider(client_id="test_client") + + with patch("aiohttp.ClientSession") as MockSession: + mock_resp = create_mock_aiohttp_response(401, text_data="Unauthorized") + mock_session_inst = create_mock_aiohttp_session(get=mock_resp) + MockSession.return_value.__aenter__.return_value = mock_session_inst + MockSession.return_value.__aexit__ = AsyncMock(return_value=False) + + with pytest.raises(ValueError, match="Failed to get user info"): + await provider.get_user_info("invalid_token") + + def test_github_extract_user_id(self): + """Test extracting user ID from GitHub user info.""" + provider = GitHubOAuthProvider(client_id="test_client") + + user_info = {"id": 12345, "login": "octocat"} + user_id = provider.extract_user_id(user_info) + + assert user_id == "12345" + + def test_github_extract_username(self): + """Test extracting username from GitHub user info.""" + provider = GitHubOAuthProvider(client_id="test_client") + + user_info = {"id": 12345, "login": "octocat"} + username = provider.extract_username(user_info) + + assert username == "octocat" + + +class TestGoogleOAuthProvider: + """Tests for Google OAuth provider.""" + + def test_google_provider_initialization(self): + """Test Google provider initialization with default config.""" + provider = GoogleOAuthProvider( + client_id="google_client_id", + client_secret="google_secret", + ) + + assert provider.config.client_id == "google_client_id" + assert provider.config.client_secret == "google_secret" + assert provider.config.authorization_url == "https://accounts.google.com/o/oauth2/v2/auth" + assert provider.config.token_url == "https://oauth2.googleapis.com/token" + assert provider.config.scope == ["openid", "email", "profile"] + + def test_google_additional_auth_params(self): + """Test Google-specific auth parameters.""" + provider = GoogleOAuthProvider( + client_id="google_client_id", + client_secret="google_secret", + ) + + params = provider.get_additional_auth_params() + + assert params["access_type"] == "offline" + assert params["prompt"] == "consent" + + @pytest.mark.asyncio + async def test_google_get_user_info(self): + """Test Google user info retrieval.""" + provider = GoogleOAuthProvider( + client_id="test_client", + client_secret="test_secret", + ) + + mock_user_data = { + "id": "google123", + "email": "user@gmail.com", + "name": "Test User", + "picture": "https://example.com/photo.jpg", + } + + with patch("aiohttp.ClientSession") as MockSession: + mock_resp = create_mock_aiohttp_response(200, json_data=mock_user_data, text_data="success") + mock_session_inst = create_mock_aiohttp_session(get=mock_resp) + MockSession.return_value.__aenter__.return_value = mock_session_inst + MockSession.return_value.__aexit__.return_value = AsyncMock() + + user_info = await provider.get_user_info("test_access_token") + + assert user_info["id"] == "google123" + assert user_info["email"] == "user@gmail.com" + + def test_google_extract_user_id(self): + """Test extracting user ID from Google user info.""" + provider = GoogleOAuthProvider( + client_id="test_client", + client_secret="test_secret", + ) + + user_info = {"id": "google123", "email": "user@gmail.com"} + user_id = provider.extract_user_id(user_info) + + assert user_id == "google123" + + def test_google_extract_username(self): + """Test extracting username from Google user info.""" + provider = GoogleOAuthProvider( + client_id="test_client", + client_secret="test_secret", + ) + + user_info = {"id": "google123", "email": "user@gmail.com"} + username = provider.extract_username(user_info) + + assert username == "user@gmail.com" diff --git a/tests/test_oauth_integration.py b/tests/test_oauth_integration.py new file mode 100644 index 0000000..41b6a7f --- /dev/null +++ b/tests/test_oauth_integration.py @@ -0,0 +1,423 @@ +""" +OAuth Integration Tests - Requires Real Credentials + +These tests perform actual OAuth flows with GitHub and Google. +They are skipped by default and require: +1. OAuth app credentials (client ID and secret) +2. Pre-obtained access tokens (for testing authenticated endpoints) + +Setup Instructions: +See: docs/OAUTH_TESTING_SETUP.md + +Run these tests with: + pytest tests/test_oauth_integration.py -v -m integration + +Or skip them (default): + pytest # automatically skips integration tests +""" + +import os + +import pytest + +from nextmcp.auth import GitHubOAuthProvider, GoogleOAuthProvider + +# Mark all tests in this module as integration tests +pytestmark = pytest.mark.integration + +# ============================================================================ +# CONFIGURATION - Tests skip if these environment variables are not set +# ============================================================================ + +GITHUB_CLIENT_ID = os.getenv("GITHUB_CLIENT_ID") +GITHUB_CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET") +GITHUB_ACCESS_TOKEN = os.getenv("GITHUB_ACCESS_TOKEN") # Pre-obtained for testing + +GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID") +GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET") +GOOGLE_ACCESS_TOKEN = os.getenv("GOOGLE_ACCESS_TOKEN") # Pre-obtained for testing +GOOGLE_REFRESH_TOKEN = os.getenv("GOOGLE_REFRESH_TOKEN") # For refresh tests + +# Skip conditions +skip_github_no_creds = pytest.mark.skipif( + not GITHUB_CLIENT_ID or not GITHUB_CLIENT_SECRET, + reason="GitHub OAuth credentials not configured. Set GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET", +) + +skip_github_no_token = pytest.mark.skipif( + not GITHUB_ACCESS_TOKEN, + reason="GitHub access token not configured. Set GITHUB_ACCESS_TOKEN", +) + +skip_google_no_creds = pytest.mark.skipif( + not GOOGLE_CLIENT_ID or not GOOGLE_CLIENT_SECRET, + reason="Google OAuth credentials not configured. Set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET", +) + +skip_google_no_token = pytest.mark.skipif( + not GOOGLE_ACCESS_TOKEN, + reason="Google access token not configured. Set GOOGLE_ACCESS_TOKEN", +) + +skip_google_no_refresh = pytest.mark.skipif( + not GOOGLE_REFRESH_TOKEN, + reason="Google refresh token not configured. Set GOOGLE_REFRESH_TOKEN", +) + + +# ============================================================================ +# GITHUB OAUTH INTEGRATION TESTS +# ============================================================================ + + +class TestGitHubOAuthIntegration: + """Integration tests for GitHub OAuth provider.""" + + @skip_github_no_creds + def test_github_authorization_url_generation(self): + """ + Test generating real GitHub authorization URL. + + This test verifies the authorization URL is correctly formatted + and can be used to start the OAuth flow. + """ + provider = GitHubOAuthProvider( + client_id=GITHUB_CLIENT_ID, + client_secret=GITHUB_CLIENT_SECRET, + scope=["read:user", "repo"], + ) + + auth_data = provider.generate_authorization_url() + + # Verify structure + assert "url" in auth_data + assert "state" in auth_data + assert "verifier" in auth_data + + # Verify URL format + url = auth_data["url"] + assert url.startswith("https://github.com/login/oauth/authorize") + assert f"client_id={GITHUB_CLIENT_ID}" in url + assert "code_challenge=" in url + assert "code_challenge_method=S256" in url + assert "scope=read%3Auser+repo" in url or "scope=read:user+repo" in url + + print(f"\n✓ GitHub authorization URL generated successfully") + print(f" URL: {url[:80]}...") + print(f" State: {auth_data['state']}") + + @skip_github_no_creds + @skip_github_no_token + @pytest.mark.asyncio + async def test_github_get_user_info(self): + """ + Test retrieving user info from GitHub with real access token. + + Requires: GITHUB_ACCESS_TOKEN environment variable + + To get a token, see: docs/OAUTH_TESTING_SETUP.md + """ + provider = GitHubOAuthProvider( + client_id=GITHUB_CLIENT_ID, + client_secret=GITHUB_CLIENT_SECRET, + ) + + # Get user info using access token + user_info = await provider.get_user_info(GITHUB_ACCESS_TOKEN) + + # Verify response structure + assert "id" in user_info + assert "login" in user_info + + print(f"\n✓ GitHub user info retrieved successfully") + print(f" User ID: {user_info.get('id')}") + print(f" Username: {user_info.get('login')}") + print(f" Name: {user_info.get('name', 'N/A')}") + print(f" Email: {user_info.get('email', 'N/A')}") + + @skip_github_no_creds + @skip_github_no_token + @pytest.mark.asyncio + async def test_github_authentication_with_token(self): + """ + Test full GitHub authentication flow with access token. + + This tests the authenticate() method which would normally be + called after the OAuth flow completes. + """ + provider = GitHubOAuthProvider( + client_id=GITHUB_CLIENT_ID, + client_secret=GITHUB_CLIENT_SECRET, + ) + + # Authenticate with access token + result = await provider.authenticate({ + "access_token": GITHUB_ACCESS_TOKEN, + "scopes": ["read:user", "repo"], + }) + + # Verify authentication success + assert result.success is True + assert result.context is not None + assert result.context.authenticated is True + assert result.context.user_id is not None + + # Verify scopes were added + assert len(result.context.scopes) > 0 + + print(f"\n✓ GitHub authentication successful") + print(f" User ID: {result.context.user_id}") + print(f" Username: {result.context.username}") + print(f" Scopes: {list(result.context.scopes)}") + print(f" Permissions: {[p.name for p in result.context.permissions]}") + + @skip_github_no_creds + @pytest.mark.asyncio + async def test_github_invalid_token_handling(self): + """ + Test that invalid tokens are properly rejected. + + This ensures error handling works correctly. + """ + provider = GitHubOAuthProvider( + client_id=GITHUB_CLIENT_ID, + client_secret=GITHUB_CLIENT_SECRET, + ) + + # Try to authenticate with invalid token + result = await provider.authenticate({ + "access_token": "invalid_token_12345", + "scopes": ["read:user"], + }) + + # Should fail gracefully + assert result.success is False + assert result.error is not None + + print(f"\n✓ Invalid GitHub token correctly rejected") + print(f" Error: {result.error}") + + +# ============================================================================ +# GOOGLE OAUTH INTEGRATION TESTS +# ============================================================================ + + +class TestGoogleOAuthIntegration: + """Integration tests for Google OAuth provider.""" + + @skip_google_no_creds + def test_google_authorization_url_generation(self): + """ + Test generating real Google authorization URL. + + This test verifies the authorization URL includes required parameters + for Google OAuth with offline access (refresh tokens). + """ + provider = GoogleOAuthProvider( + client_id=GOOGLE_CLIENT_ID, + client_secret=GOOGLE_CLIENT_SECRET, + scope=[ + "https://www.googleapis.com/auth/userinfo.profile", + "https://www.googleapis.com/auth/userinfo.email", + ], + ) + + auth_data = provider.generate_authorization_url() + + # Verify structure + assert "url" in auth_data + assert "state" in auth_data + assert "verifier" in auth_data + + # Verify URL format + url = auth_data["url"] + assert url.startswith("https://accounts.google.com/o/oauth2/v2/auth") + assert f"client_id={GOOGLE_CLIENT_ID}" in url + assert "code_challenge=" in url + assert "code_challenge_method=S256" in url + assert "access_type=offline" in url # Important for refresh tokens + assert "prompt=consent" in url + + print(f"\n✓ Google authorization URL generated successfully") + print(f" URL: {url[:80]}...") + print(f" State: {auth_data['state']}") + print(f" Includes offline access: Yes") + + @skip_google_no_creds + @skip_google_no_token + @pytest.mark.asyncio + async def test_google_get_user_info(self): + """ + Test retrieving user info from Google with real access token. + + Requires: GOOGLE_ACCESS_TOKEN environment variable + + To get a token, see: docs/OAUTH_TESTING_SETUP.md + """ + provider = GoogleOAuthProvider( + client_id=GOOGLE_CLIENT_ID, + client_secret=GOOGLE_CLIENT_SECRET, + ) + + # Get user info using access token + user_info = await provider.get_user_info(GOOGLE_ACCESS_TOKEN) + + # Verify response structure (Google's userinfo v2 endpoint) + assert "id" in user_info # Google user ID (v2 endpoint uses 'id' not 'sub') + assert "email" in user_info or "name" in user_info + + print(f"\n✓ Google user info retrieved successfully") + print(f" User ID: {user_info.get('id')}") + print(f" Email: {user_info.get('email', 'N/A')}") + print(f" Name: {user_info.get('name', 'N/A')}") + print(f" Picture: {user_info.get('picture', 'N/A')[:50]}...") + + @skip_google_no_creds + @skip_google_no_token + @pytest.mark.asyncio + async def test_google_authentication_with_token(self): + """ + Test full Google authentication flow with access token. + + This tests the authenticate() method which would normally be + called after the OAuth flow completes. + """ + provider = GoogleOAuthProvider( + client_id=GOOGLE_CLIENT_ID, + client_secret=GOOGLE_CLIENT_SECRET, + ) + + # Authenticate with access token + result = await provider.authenticate({ + "access_token": GOOGLE_ACCESS_TOKEN, + "scopes": [ + "https://www.googleapis.com/auth/userinfo.profile", + "https://www.googleapis.com/auth/userinfo.email", + ], + }) + + # Verify authentication success + assert result.success is True + assert result.context is not None + assert result.context.authenticated is True + assert result.context.user_id is not None + + # Verify scopes were added + assert len(result.context.scopes) > 0 + + print(f"\n✓ Google authentication successful") + print(f" User ID: {result.context.user_id}") + print(f" Username: {result.context.username}") + print(f" Scopes: {list(result.context.scopes)}") + print(f" Permissions: {[p.name for p in result.context.permissions]}") + + @skip_google_no_creds + @skip_google_no_refresh + @pytest.mark.asyncio + async def test_google_token_refresh(self): + """ + Test refreshing an expired access token. + + Requires: GOOGLE_REFRESH_TOKEN environment variable + + This tests the token refresh flow, which is unique to Google OAuth + (with offline access). + """ + provider = GoogleOAuthProvider( + client_id=GOOGLE_CLIENT_ID, + client_secret=GOOGLE_CLIENT_SECRET, + ) + + # Refresh the token + token_data = await provider.refresh_access_token(GOOGLE_REFRESH_TOKEN) + + # Verify token response + assert "access_token" in token_data + assert "expires_in" in token_data + assert "token_type" in token_data + + print(f"\n✓ Google token refresh successful") + print(f" New access token: {token_data['access_token'][:20]}...") + print(f" Expires in: {token_data['expires_in']} seconds") + print(f" Token type: {token_data['token_type']}") + + @skip_google_no_creds + @pytest.mark.asyncio + async def test_google_invalid_token_handling(self): + """ + Test that invalid tokens are properly rejected. + + This ensures error handling works correctly. + """ + provider = GoogleOAuthProvider( + client_id=GOOGLE_CLIENT_ID, + client_secret=GOOGLE_CLIENT_SECRET, + ) + + # Try to authenticate with invalid token + result = await provider.authenticate({ + "access_token": "invalid_token_12345", + "scopes": ["https://www.googleapis.com/auth/userinfo.profile"], + }) + + # Should fail gracefully + assert result.success is False + assert result.error is not None + + print(f"\n✓ Invalid Google token correctly rejected") + print(f" Error: {result.error}") + + @skip_google_no_creds + @pytest.mark.asyncio + async def test_google_invalid_refresh_token_handling(self): + """ + Test that invalid refresh tokens are properly rejected. + """ + provider = GoogleOAuthProvider( + client_id=GOOGLE_CLIENT_ID, + client_secret=GOOGLE_CLIENT_SECRET, + ) + + # Try to refresh with invalid token + with pytest.raises(ValueError, match="Token refresh failed"): + await provider.refresh_access_token("invalid_refresh_token_12345") + + print(f"\n✓ Invalid refresh token correctly rejected") + + +# ============================================================================ +# USAGE INSTRUCTIONS +# ============================================================================ + +def test_show_setup_instructions(): + """ + Display setup instructions when integration tests are run. + + This is always shown to help users understand what's needed. + """ + print("\n" + "=" * 70) + print("OAUTH INTEGRATION TESTS - SETUP REQUIRED") + print("=" * 70) + print("\nThese tests require OAuth credentials and access tokens.") + print("\nQuick Start:") + print(" 1. See docs/OAUTH_TESTING_SETUP.md for detailed instructions") + print(" 2. Run: python examples/auth/oauth_token_helper.py") + print(" 3. Set environment variables with your tokens") + print("\nRequired Environment Variables:") + print(" GitHub Tests:") + print(" - GITHUB_CLIENT_ID") + print(" - GITHUB_CLIENT_SECRET") + print(" - GITHUB_ACCESS_TOKEN (for authenticated tests)") + print("\n Google Tests:") + print(" - GOOGLE_CLIENT_ID") + print(" - GOOGLE_CLIENT_SECRET") + print(" - GOOGLE_ACCESS_TOKEN (for authenticated tests)") + print(" - GOOGLE_REFRESH_TOKEN (for refresh tests)") + print("\nCurrent Status:") + print(f" GitHub credentials: {'✓ Set' if GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET else '✗ Not set'}") + print(f" GitHub token: {'✓ Set' if GITHUB_ACCESS_TOKEN else '✗ Not set'}") + print(f" Google credentials: {'✓ Set' if GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET else '✗ Not set'}") + print(f" Google token: {'✓ Set' if GOOGLE_ACCESS_TOKEN else '✗ Not set'}") + print(f" Google refresh: {'✓ Set' if GOOGLE_REFRESH_TOKEN else '✗ Not set'}") + print("\n" + "=" * 70) diff --git a/tests/test_request_middleware.py b/tests/test_request_middleware.py new file mode 100644 index 0000000..1b5d922 --- /dev/null +++ b/tests/test_request_middleware.py @@ -0,0 +1,408 @@ +""" +Tests for Runtime Auth Enforcement Middleware. + +Tests the request-level auth enforcement that validates every request. +""" + +import pytest + +from nextmcp.auth.core import AuthContext, AuthResult +from nextmcp.auth.errors import AuthenticationError, AuthorizationError +from nextmcp.auth.manifest import PermissionManifest +from nextmcp.auth.oauth_providers import GitHubOAuthProvider +from nextmcp.auth.request_middleware import ( + AuthEnforcementMiddleware, + create_auth_middleware, +) +from nextmcp.protocol.auth_metadata import AuthMetadata, AuthRequirement +from nextmcp.session.session_store import MemorySessionStore, SessionData + + +class MockAuthProvider: + """Mock auth provider for testing.""" + + def __init__(self, should_succeed=True, user_id="test_user"): + self.name = "mock" + self.should_succeed = should_succeed + self.user_id = user_id + self.authenticate_called = False + + async def authenticate(self, credentials): + self.authenticate_called = True + + if not self.should_succeed: + return AuthResult.failure("Authentication failed") + + context = AuthContext( + authenticated=True, + user_id=self.user_id, + username="testuser", + ) + context.add_scope("read:data") + return AuthResult.success_result(context) + + +@pytest.mark.asyncio +class TestAuthEnforcementMiddleware: + """Test AuthEnforcementMiddleware functionality.""" + + async def test_no_auth_required_passes_through(self): + """Test request passes through when auth not required.""" + provider = MockAuthProvider() + metadata = AuthMetadata(requirement=AuthRequirement.NONE) + middleware = AuthEnforcementMiddleware(provider=provider, metadata=metadata) + + request = {"method": "tools/call", "params": {"name": "test_tool"}} + handler_called = False + + async def handler(req): + nonlocal handler_called + handler_called = True + return {"success": True} + + result = await middleware(request, handler) + + assert handler_called is True + assert result["success"] is True + assert provider.authenticate_called is False # Should not authenticate + + async def test_optional_auth_without_credentials_passes(self): + """Test optional auth passes without credentials.""" + provider = MockAuthProvider() + metadata = AuthMetadata(requirement=AuthRequirement.OPTIONAL) + middleware = AuthEnforcementMiddleware(provider=provider, metadata=metadata) + + request = {"method": "tools/call", "params": {"name": "test_tool"}} + + async def handler(req): + return {"success": True} + + result = await middleware(request, handler) + + assert result["success"] is True + assert provider.authenticate_called is False + + async def test_required_auth_without_credentials_fails(self): + """Test required auth fails without credentials.""" + provider = MockAuthProvider() + metadata = AuthMetadata(requirement=AuthRequirement.REQUIRED) + middleware = AuthEnforcementMiddleware(provider=provider, metadata=metadata) + + request = {"method": "tools/call", "params": {"name": "test_tool"}} + + async def handler(req): + return {"success": True} + + with pytest.raises(AuthenticationError, match="no credentials provided"): + await middleware(request, handler) + + async def test_successful_authentication(self): + """Test successful authentication flow.""" + provider = MockAuthProvider(should_succeed=True) + metadata = AuthMetadata(requirement=AuthRequirement.REQUIRED) + middleware = AuthEnforcementMiddleware(provider=provider, metadata=metadata) + + request = { + "method": "tools/call", + "params": {"name": "test_tool"}, + "auth": {"access_token": "valid_token"}, + } + + async def handler(req): + # Check auth context was injected + assert "_auth_context" in req + assert req["_auth_context"].authenticated is True + return {"success": True} + + result = await middleware(request, handler) + + assert result["success"] is True + assert provider.authenticate_called is True + + async def test_failed_authentication(self): + """Test failed authentication.""" + provider = MockAuthProvider(should_succeed=False) + metadata = AuthMetadata(requirement=AuthRequirement.REQUIRED) + middleware = AuthEnforcementMiddleware(provider=provider, metadata=metadata) + + request = { + "method": "tools/call", + "params": {"name": "test_tool"}, + "auth": {"access_token": "invalid_token"}, + } + + async def handler(req): + return {"success": True} + + with pytest.raises(AuthenticationError, match="Authentication failed"): + await middleware(request, handler) + + async def test_session_store_integration(self): + """Test integration with session store.""" + provider = MockAuthProvider(should_succeed=True, user_id="user123") + session_store = MemorySessionStore() + metadata = AuthMetadata(requirement=AuthRequirement.REQUIRED) + middleware = AuthEnforcementMiddleware( + provider=provider, + session_store=session_store, + metadata=metadata, + ) + + # First request should authenticate and create session + request = { + "method": "tools/call", + "params": {"name": "test_tool"}, + "auth": { + "access_token": "token123", + "refresh_token": "refresh123", + }, + } + + async def handler(req): + return {"success": True} + + result = await middleware(request, handler) + assert result["success"] is True + + # Check session was created + session = session_store.load("user123") + assert session is not None + assert session.access_token == "token123" + + async def test_session_reuse(self): + """Test reusing existing session.""" + provider = MockAuthProvider(should_succeed=True, user_id="user123") + session_store = MemorySessionStore() + + # Pre-populate session + session = SessionData( + user_id="user123", + access_token="token123", + scopes=["read:data"], + user_info={"login": "testuser"}, + provider="mock", + ) + session_store.save(session) + + metadata = AuthMetadata(requirement=AuthRequirement.REQUIRED) + middleware = AuthEnforcementMiddleware( + provider=provider, + session_store=session_store, + metadata=metadata, + ) + + request = { + "method": "tools/call", + "params": {"name": "test_tool"}, + "auth": {"access_token": "token123"}, + } + + async def handler(req): + # Check auth context from session + assert "_auth_context" in req + assert req["_auth_context"].user_id == "user123" + return {"success": True} + + result = await middleware(request, handler) + assert result["success"] is True + + async def test_expired_token_rejection(self): + """Test expired tokens are rejected.""" + import time + + provider = MockAuthProvider() + session_store = MemorySessionStore() + + # Create expired session + session = SessionData( + user_id="user123", + access_token="expired_token", + expires_at=time.time() - 10, # Expired 10 seconds ago + ) + session_store.save(session) + + metadata = AuthMetadata(requirement=AuthRequirement.REQUIRED) + middleware = AuthEnforcementMiddleware( + provider=provider, + session_store=session_store, + metadata=metadata, + ) + + request = { + "method": "tools/call", + "params": {"name": "test_tool"}, + "auth": {"access_token": "expired_token"}, + } + + async def handler(req): + return {"success": True} + + with pytest.raises(AuthenticationError, match="expired"): + await middleware(request, handler) + + async def test_scope_enforcement(self): + """Test required scopes are enforced.""" + provider = MockAuthProvider(should_succeed=True) + metadata = AuthMetadata( + requirement=AuthRequirement.REQUIRED, + required_scopes=["write:data"], # Requires write scope + ) + middleware = AuthEnforcementMiddleware(provider=provider, metadata=metadata) + + request = { + "method": "tools/call", + "params": {"name": "test_tool"}, + "auth": {"access_token": "token"}, + } + + async def handler(req): + return {"success": True} + + # Provider only gives "read:data" scope, should fail + with pytest.raises(AuthorizationError, match="Missing required scopes"): + await middleware(request, handler) + + async def test_manifest_enforcement(self): + """Test permission manifest is enforced.""" + from nextmcp.auth.errors import ManifestViolationError + + provider = MockAuthProvider(should_succeed=True) + manifest = PermissionManifest() + manifest.define_tool_permission("admin_tool", roles=["admin"]) + + metadata = AuthMetadata(requirement=AuthRequirement.REQUIRED) + middleware = AuthEnforcementMiddleware( + provider=provider, + metadata=metadata, + manifest=manifest, + ) + + request = { + "method": "tools/call", + "params": {"name": "admin_tool"}, + "auth": {"access_token": "token"}, + } + + async def handler(req): + return {"success": True} + + # User doesn't have admin role, should fail + with pytest.raises(ManifestViolationError): + await middleware(request, handler) + + async def test_manifest_allows_authorized_user(self): + """Test manifest allows authorized user.""" + + class AdminProvider(MockAuthProvider): + async def authenticate(self, credentials): + context = AuthContext( + authenticated=True, + user_id="admin_user", + username="admin", + ) + context.add_role("admin") + return AuthResult.success_result(context) + + provider = AdminProvider() + manifest = PermissionManifest() + manifest.define_tool_permission("admin_tool", roles=["admin"]) + + metadata = AuthMetadata(requirement=AuthRequirement.REQUIRED) + middleware = AuthEnforcementMiddleware( + provider=provider, + metadata=metadata, + manifest=manifest, + ) + + request = { + "method": "tools/call", + "params": {"name": "admin_tool"}, + "auth": {"access_token": "token"}, + } + + async def handler(req): + return {"success": True} + + result = await middleware(request, handler) + assert result["success"] is True + + async def test_non_tool_request_allowed(self): + """Test non-tool requests are allowed without auth checks.""" + provider = MockAuthProvider() + manifest = PermissionManifest() + manifest.define_tool_permission("protected_tool", roles=["admin"]) + + metadata = AuthMetadata(requirement=AuthRequirement.REQUIRED) + middleware = AuthEnforcementMiddleware( + provider=provider, + metadata=metadata, + manifest=manifest, + ) + + # Request without tool name (e.g., server info request) + request = { + "method": "server/info", + "auth": {"access_token": "token"}, + } + + async def handler(req): + return {"success": True} + + result = await middleware(request, handler) + assert result["success"] is True + + +@pytest.mark.asyncio +class TestCreateAuthMiddleware: + """Test create_auth_middleware helper function.""" + + async def test_create_middleware_with_defaults(self): + """Test creating middleware with default settings.""" + provider = MockAuthProvider() + middleware = create_auth_middleware(provider=provider) + + assert middleware.provider == provider + assert middleware.metadata.requirement == AuthRequirement.REQUIRED + + async def test_create_middleware_with_optional_auth(self): + """Test creating middleware with optional auth.""" + provider = MockAuthProvider() + middleware = create_auth_middleware( + provider=provider, + requirement=AuthRequirement.OPTIONAL, + ) + + assert middleware.metadata.requirement == AuthRequirement.OPTIONAL + + async def test_create_middleware_with_scopes(self): + """Test creating middleware with required scopes.""" + provider = MockAuthProvider() + middleware = create_auth_middleware( + provider=provider, + required_scopes=["read:repo", "write:repo"], + ) + + assert "read:repo" in middleware.metadata.required_scopes + assert "write:repo" in middleware.metadata.required_scopes + + async def test_create_middleware_with_session_store(self): + """Test creating middleware with session store.""" + provider = MockAuthProvider() + session_store = MemorySessionStore() + middleware = create_auth_middleware( + provider=provider, + session_store=session_store, + ) + + assert middleware.session_store == session_store + + async def test_create_middleware_with_manifest(self): + """Test creating middleware with permission manifest.""" + provider = MockAuthProvider() + manifest = PermissionManifest() + middleware = create_auth_middleware( + provider=provider, + manifest=manifest, + ) + + assert middleware.manifest == manifest diff --git a/tests/test_scopes.py b/tests/test_scopes.py new file mode 100644 index 0000000..390831c --- /dev/null +++ b/tests/test_scopes.py @@ -0,0 +1,433 @@ +""" +Tests for OAuth scope system. + +Tests for scope support in AuthContext and scope-based authorization decorators. +""" + +import pytest + +from nextmcp.auth.core import AuthContext, Permission, Role +from nextmcp.auth.middleware import AuthenticationError, requires_auth_async, requires_scope_async + + +class TestAuthContextScopes: + """Tests for scope support in AuthContext.""" + + def test_auth_context_with_scopes(self): + """Test creating AuthContext with scopes.""" + context = AuthContext( + authenticated=True, + user_id="user123", + scopes={"read:data", "write:data"}, + ) + + assert context.authenticated is True + assert context.user_id == "user123" + assert len(context.scopes) == 2 + assert "read:data" in context.scopes + assert "write:data" in context.scopes + + def test_auth_context_default_empty_scopes(self): + """Test that AuthContext has empty scopes by default.""" + context = AuthContext(authenticated=True, user_id="user123") + + assert context.scopes == set() + assert len(context.scopes) == 0 + + def test_has_scope_returns_true_for_existing_scope(self): + """Test has_scope returns True for scopes that exist.""" + context = AuthContext( + authenticated=True, + user_id="user123", + scopes={"read:data", "write:data", "admin:all"}, + ) + + assert context.has_scope("read:data") is True + assert context.has_scope("write:data") is True + assert context.has_scope("admin:all") is True + + def test_has_scope_returns_false_for_missing_scope(self): + """Test has_scope returns False for scopes that don't exist.""" + context = AuthContext(authenticated=True, user_id="user123", scopes={"read:data"}) + + assert context.has_scope("write:data") is False + assert context.has_scope("admin:all") is False + assert context.has_scope("delete:data") is False + + def test_has_scope_case_sensitive(self): + """Test that scope checking is case-sensitive.""" + context = AuthContext(authenticated=True, user_id="user123", scopes={"read:data"}) + + assert context.has_scope("read:data") is True + assert context.has_scope("READ:DATA") is False + assert context.has_scope("Read:Data") is False + + def test_add_scope_single(self): + """Test adding a single scope to AuthContext.""" + context = AuthContext(authenticated=True, user_id="user123") + + context.add_scope("read:data") + + assert context.has_scope("read:data") is True + assert len(context.scopes) == 1 + + def test_add_scope_multiple(self): + """Test adding multiple scopes to AuthContext.""" + context = AuthContext(authenticated=True, user_id="user123") + + context.add_scope("read:data") + context.add_scope("write:data") + context.add_scope("admin:all") + + assert len(context.scopes) == 3 + assert context.has_scope("read:data") is True + assert context.has_scope("write:data") is True + assert context.has_scope("admin:all") is True + + def test_add_scope_duplicate_ignored(self): + """Test that adding duplicate scopes doesn't create duplicates.""" + context = AuthContext(authenticated=True, user_id="user123") + + context.add_scope("read:data") + context.add_scope("read:data") # Duplicate + context.add_scope("read:data") # Duplicate + + assert len(context.scopes) == 1 + assert context.has_scope("read:data") is True + + def test_scopes_and_permissions_coexist(self): + """Test that scopes and permissions can coexist in AuthContext.""" + context = AuthContext( + authenticated=True, + user_id="user123", + permissions={Permission("read:posts"), Permission("write:posts")}, + scopes={"repo:read", "repo:write"}, + ) + + # Check permissions + assert context.has_permission("read:posts") is True + assert context.has_permission("write:posts") is True + + # Check scopes + assert context.has_scope("repo:read") is True + assert context.has_scope("repo:write") is True + + # Verify they're separate + assert context.has_permission("repo:read") is False # Not a permission + assert context.has_scope("read:posts") is False # Not a scope + + def test_scopes_and_roles_coexist(self): + """Test that scopes and roles can coexist in AuthContext.""" + context = AuthContext( + authenticated=True, + user_id="user123", + roles={Role("admin"), Role("editor")}, + scopes={"repo:read", "repo:write"}, + ) + + # Check roles + assert context.has_role("admin") is True + assert context.has_role("editor") is True + + # Check scopes + assert context.has_scope("repo:read") is True + assert context.has_scope("repo:write") is True + + +class MockAuthProvider: + """Mock auth provider for testing.""" + + async def authenticate(self, credentials): + from nextmcp.auth.core import AuthResult + + if credentials.get("valid"): + context = AuthContext( + authenticated=True, + user_id="user123", + scopes=set(credentials.get("scopes", [])), + ) + return AuthResult.success_result(context) + return AuthResult.failure("Invalid credentials") + + +class TestRequiresScopeDecorator: + """Tests for @requires_scope_async decorator.""" + + @pytest.mark.asyncio + async def test_requires_scope_single_scope_success(self): + """Test @requires_scope_async with single scope - success case.""" + provider = MockAuthProvider() + + @requires_auth_async(provider=provider) + @requires_scope_async("read:data") + async def protected_function(auth: AuthContext): + return f"Success for {auth.user_id}" + + # Call with valid credentials including required scope + result = await protected_function(auth={"valid": True, "scopes": ["read:data"]}) + + assert result == "Success for user123" + + @pytest.mark.asyncio + async def test_requires_scope_single_scope_failure(self): + """Test @requires_scope_async with single scope - missing scope.""" + provider = MockAuthProvider() + + @requires_auth_async(provider=provider) + @requires_scope_async("write:data") + async def protected_function(auth: AuthContext): + return f"Success for {auth.user_id}" + + # Call with valid credentials but missing required scope + with pytest.raises(Exception) as exc_info: + await protected_function(auth={"valid": True, "scopes": ["read:data"]}) + + # Should raise an error about insufficient scopes + assert "scope" in str(exc_info.value).lower() or "permission" in str(exc_info.value).lower() + + @pytest.mark.asyncio + async def test_requires_scope_multiple_scopes_any_matches(self): + """Test @requires_scope_async with multiple scopes - any one matches.""" + provider = MockAuthProvider() + + @requires_auth_async(provider=provider) + @requires_scope_async("read:data", "write:data", "admin:all") + async def protected_function(auth: AuthContext): + return f"Success for {auth.user_id}" + + # User has write:data (one of the required scopes) + result = await protected_function(auth={"valid": True, "scopes": ["write:data"]}) + assert result == "Success for user123" + + # User has admin:all (another required scope) + result = await protected_function(auth={"valid": True, "scopes": ["admin:all"]}) + assert result == "Success for user123" + + @pytest.mark.asyncio + async def test_requires_scope_multiple_scopes_none_match(self): + """Test @requires_scope_async with multiple scopes - none match.""" + provider = MockAuthProvider() + + @requires_auth_async(provider=provider) + @requires_scope_async("read:data", "write:data", "admin:all") + async def protected_function(auth: AuthContext): + return f"Success for {auth.user_id}" + + # User has different scope + with pytest.raises(Exception): + await protected_function(auth={"valid": True, "scopes": ["other:scope"]}) + + @pytest.mark.asyncio + async def test_requires_scope_with_multiple_user_scopes(self): + """Test @requires_scope_async when user has multiple scopes.""" + provider = MockAuthProvider() + + @requires_auth_async(provider=provider) + @requires_scope_async("write:data") + async def protected_function(auth: AuthContext): + return f"Success for {auth.user_id}" + + # User has multiple scopes including the required one + result = await protected_function( + auth={"valid": True, "scopes": ["read:data", "write:data", "admin:all"]} + ) + assert result == "Success for user123" + + @pytest.mark.asyncio + async def test_requires_scope_preserves_function_metadata(self): + """Test that @requires_scope_async preserves function metadata.""" + provider = MockAuthProvider() + + @requires_auth_async(provider=provider) + @requires_scope_async("read:data") + async def my_function(auth: AuthContext): + """My function docstring.""" + return "result" + + assert my_function.__name__ == "my_function" + assert my_function.__doc__ == "My function docstring." + + @pytest.mark.asyncio + async def test_requires_scope_without_auth_decorator_fails(self): + """Test that @requires_scope_async requires @requires_auth_async.""" + + @requires_scope_async("read:data") + async def unprotected_function(param: str): + return f"Result: {param}" + + # Should fail because first argument is not AuthContext + with pytest.raises(AuthenticationError, match="requires_scope_async must be used with"): + await unprotected_function("test") + + @pytest.mark.asyncio + async def test_requires_scope_stacking_multiple_decorators(self): + """Test stacking multiple @requires_scope_async decorators.""" + provider = MockAuthProvider() + + @requires_auth_async(provider=provider) + @requires_scope_async("read:data") + @requires_scope_async("write:data") + async def protected_function(auth: AuthContext): + return f"Success for {auth.user_id}" + + # User must have both scopes + result = await protected_function( + auth={"valid": True, "scopes": ["read:data", "write:data"]} + ) + assert result == "Success for user123" + + # Missing one scope should fail + with pytest.raises(Exception): + await protected_function(auth={"valid": True, "scopes": ["read:data"]}) + + @pytest.mark.asyncio + async def test_requires_scope_with_sync_function(self): + """Test @requires_scope_async works with sync functions too.""" + provider = MockAuthProvider() + + @requires_auth_async(provider=provider) + @requires_scope_async("read:data") + def sync_protected_function(auth: AuthContext): + return f"Sync success for {auth.user_id}" + + # Wrapper is async even if decorated function is sync + result = await sync_protected_function(auth={"valid": True, "scopes": ["read:data"]}) + assert result == "Sync success for user123" + + +class TestScopeIntegrationWithOAuth: + """Tests for scope integration with OAuth providers.""" + + @pytest.mark.asyncio + async def test_oauth_provider_adds_scopes_to_context(self): + """Test that OAuth providers correctly add scopes to AuthContext.""" + from unittest.mock import AsyncMock, MagicMock, patch + + from nextmcp.auth.oauth import OAuthConfig, OAuthProvider + + class TestOAuthProvider(OAuthProvider): + async def get_user_info(self, access_token): + return {"id": "123", "login": "testuser"} + + def get_additional_auth_params(self): + return {} + + def extract_user_id(self, user_info): + return str(user_info["id"]) + + config = OAuthConfig( + client_id="test_client", + token_url="https://test.com/token", + ) + + provider = TestOAuthProvider(config) + + # Authenticate with scopes + credentials = { + "access_token": "test_token", + "scopes": ["repo:read", "repo:write", "user:email"], + } + + result = await provider.authenticate(credentials) + + assert result.success is True + assert result.context is not None + + # Verify scopes were added as permissions (current behavior) + assert result.context.has_permission("repo:read") is True + assert result.context.has_permission("repo:write") is True + assert result.context.has_permission("user:email") is True + + @pytest.mark.asyncio + async def test_oauth_with_scope_decorator(self): + """Test OAuth authentication with scope-based access control.""" + from unittest.mock import AsyncMock + + from nextmcp.auth.oauth import OAuthConfig, OAuthProvider + + class TestOAuthProvider(OAuthProvider): + async def get_user_info(self, access_token): + return {"id": "123", "login": "testuser"} + + def get_additional_auth_params(self): + return {} + + def extract_user_id(self, user_info): + return str(user_info["id"]) + + config = OAuthConfig(client_id="test_client", token_url="https://test.com/token") + provider = TestOAuthProvider(config) + + # Override authenticate to return context with actual scopes + original_auth = provider.authenticate + + async def auth_with_scopes(credentials): + result = await original_auth(credentials) + if result.success: + # Add scopes to context + for scope in credentials.get("scopes", []): + result.context.add_scope(scope) + return result + + provider.authenticate = auth_with_scopes + + @requires_auth_async(provider=provider) + @requires_scope_async("repo:read") + async def read_repos(auth: AuthContext): + return {"repos": ["repo1", "repo2"], "user": auth.user_id} + + # Test with correct scope + result = await read_repos( + auth={"access_token": "token", "scopes": ["repo:read", "user:email"]} + ) + assert result["user"] == "123" + assert "repos" in result + + # Test without required scope + with pytest.raises(Exception): + await read_repos(auth={"access_token": "token", "scopes": ["user:email"]}) + + +class TestScopeEdgeCases: + """Tests for edge cases in scope handling.""" + + def test_empty_scope_string(self): + """Test handling of empty scope strings.""" + context = AuthContext(authenticated=True, user_id="user123") + + context.add_scope("") + # Empty string should still be added (set behavior) + assert "" in context.scopes + assert context.has_scope("") is True + + def test_scope_with_special_characters(self): + """Test scopes with special characters.""" + context = AuthContext( + authenticated=True, + user_id="user123", + scopes={ + "read:data", + "write:data:all", + "admin:*", + "https://www.googleapis.com/auth/drive", + }, + ) + + assert context.has_scope("read:data") is True + assert context.has_scope("write:data:all") is True + assert context.has_scope("admin:*") is True + assert context.has_scope("https://www.googleapis.com/auth/drive") is True + + def test_scope_immutability_through_set(self): + """Test that scopes set is properly managed.""" + context = AuthContext(authenticated=True, user_id="user123") + + # Add scopes + context.add_scope("scope1") + context.add_scope("scope2") + + # Direct set manipulation should work + context.scopes.add("scope3") + + assert len(context.scopes) == 3 + assert context.has_scope("scope3") is True diff --git a/tests/test_session_store.py b/tests/test_session_store.py new file mode 100644 index 0000000..287510c --- /dev/null +++ b/tests/test_session_store.py @@ -0,0 +1,479 @@ +""" +Tests for Session Store. + +Tests the session storage implementations for OAuth tokens and user data. +""" + +import tempfile +import time +from pathlib import Path + +import pytest + +from nextmcp.session.session_store import ( + FileSessionStore, + MemorySessionStore, + SessionData, +) + + +class TestSessionData: + """Test SessionData functionality.""" + + def test_create_session(self): + """Test creating session data.""" + session = SessionData( + user_id="user123", + access_token="token_abc", + refresh_token="refresh_xyz", + scopes=["profile", "email"], + ) + + assert session.user_id == "user123" + assert session.access_token == "token_abc" + assert session.refresh_token == "refresh_xyz" + assert "profile" in session.scopes + + def test_session_not_expired(self): + """Test session not expired.""" + session = SessionData( + user_id="user123", + access_token="token", + expires_at=time.time() + 3600, # Expires in 1 hour + ) + + assert session.is_expired() is False + + def test_session_expired(self): + """Test session is expired.""" + session = SessionData( + user_id="user123", + access_token="token", + expires_at=time.time() - 10, # Expired 10 seconds ago + ) + + assert session.is_expired() is True + + def test_session_no_expiry(self): + """Test session with no expiry never expires.""" + session = SessionData( + user_id="user123", + access_token="token", + # No expires_at set + ) + + assert session.is_expired() is False + + def test_needs_refresh(self): + """Test checking if token needs refresh.""" + # Token expiring in 2 minutes + session = SessionData( + user_id="user123", + access_token="token", + expires_at=time.time() + 120, + ) + + # Should need refresh (default buffer is 5 minutes) + assert session.needs_refresh() is True + + def test_does_not_need_refresh(self): + """Test token doesn't need refresh yet.""" + # Token expiring in 10 minutes + session = SessionData( + user_id="user123", + access_token="token", + expires_at=time.time() + 600, + ) + + # Should not need refresh (default buffer is 5 minutes) + assert session.needs_refresh() is False + + def test_custom_refresh_buffer(self): + """Test custom refresh buffer.""" + # Token expiring in 2 minutes + session = SessionData( + user_id="user123", + access_token="token", + expires_at=time.time() + 120, + ) + + # With 1 minute buffer, should not need refresh + assert session.needs_refresh(buffer_seconds=60) is False + + # With 3 minute buffer, should need refresh + assert session.needs_refresh(buffer_seconds=180) is True + + def test_session_to_dict(self): + """Test serializing session to dict.""" + session = SessionData( + user_id="user123", + access_token="token_abc", + scopes=["profile", "email"], + provider="google", + ) + + data = session.to_dict() + + assert data["user_id"] == "user123" + assert data["access_token"] == "token_abc" + assert data["scopes"] == ["profile", "email"] + assert data["provider"] == "google" + + def test_session_from_dict(self): + """Test deserializing session from dict.""" + data = { + "user_id": "user123", + "access_token": "token_abc", + "refresh_token": "refresh_xyz", + "scopes": ["profile"], + "provider": "github", + "created_at": 1234567890.0, + "updated_at": 1234567890.0, + "metadata": {}, + "token_type": "Bearer", + "expires_at": None, + "user_info": {}, + } + + session = SessionData.from_dict(data) + + assert session.user_id == "user123" + assert session.access_token == "token_abc" + assert session.refresh_token == "refresh_xyz" + assert "profile" in session.scopes + assert session.provider == "github" + + def test_session_with_user_info(self): + """Test session with user information.""" + session = SessionData( + user_id="user123", + access_token="token", + user_info={ + "name": "John Doe", + "email": "john@example.com", + "avatar": "https://example.com/avatar.jpg", + }, + ) + + assert session.user_info["name"] == "John Doe" + assert session.user_info["email"] == "john@example.com" + + +class TestMemorySessionStore: + """Test MemorySessionStore functionality.""" + + def test_save_and_load_session(self): + """Test saving and loading a session.""" + store = MemorySessionStore() + session = SessionData(user_id="user123", access_token="token_abc") + + store.save(session) + loaded = store.load("user123") + + assert loaded is not None + assert loaded.user_id == "user123" + assert loaded.access_token == "token_abc" + + def test_load_nonexistent_session(self): + """Test loading a session that doesn't exist.""" + store = MemorySessionStore() + loaded = store.load("nonexistent") + + assert loaded is None + + def test_exists(self): + """Test checking if session exists.""" + store = MemorySessionStore() + session = SessionData(user_id="user123", access_token="token") + + assert store.exists("user123") is False + + store.save(session) + + assert store.exists("user123") is True + + def test_delete_session(self): + """Test deleting a session.""" + store = MemorySessionStore() + session = SessionData(user_id="user123", access_token="token") + + store.save(session) + assert store.exists("user123") is True + + deleted = store.delete("user123") + assert deleted is True + assert store.exists("user123") is False + + def test_delete_nonexistent_session(self): + """Test deleting a session that doesn't exist.""" + store = MemorySessionStore() + deleted = store.delete("nonexistent") + + assert deleted is False + + def test_list_users(self): + """Test listing all users.""" + store = MemorySessionStore() + + store.save(SessionData(user_id="user1", access_token="token1")) + store.save(SessionData(user_id="user2", access_token="token2")) + store.save(SessionData(user_id="user3", access_token="token3")) + + users = store.list_users() + + assert len(users) == 3 + assert "user1" in users + assert "user2" in users + assert "user3" in users + + def test_clear_all(self): + """Test clearing all sessions.""" + store = MemorySessionStore() + + store.save(SessionData(user_id="user1", access_token="token1")) + store.save(SessionData(user_id="user2", access_token="token2")) + + count = store.clear_all() + + assert count == 2 + assert len(store.list_users()) == 0 + + def test_update_tokens(self): + """Test updating tokens for existing session.""" + store = MemorySessionStore() + session = SessionData(user_id="user123", access_token="old_token") + store.save(session) + + store.update_tokens( + user_id="user123", + access_token="new_token", + refresh_token="new_refresh", + expires_in=3600, + ) + + updated = store.load("user123") + assert updated.access_token == "new_token" + assert updated.refresh_token == "new_refresh" + assert updated.expires_at is not None + + def test_update_tokens_nonexistent_session(self): + """Test updating tokens for nonexistent session raises error.""" + store = MemorySessionStore() + + with pytest.raises(ValueError, match="No session found"): + store.update_tokens("nonexistent", "token") + + def test_cleanup_expired(self): + """Test cleaning up expired sessions.""" + store = MemorySessionStore() + + # Create sessions: one expired, one valid + store.save( + SessionData( + user_id="expired_user", + access_token="token1", + expires_at=time.time() - 10, # Expired + ) + ) + store.save( + SessionData( + user_id="valid_user", + access_token="token2", + expires_at=time.time() + 3600, # Valid + ) + ) + + count = store.cleanup_expired() + + assert count == 1 + assert store.exists("expired_user") is False + assert store.exists("valid_user") is True + + def test_thread_safety(self): + """Test thread safety of memory store.""" + import threading + + store = MemorySessionStore() + + def save_session(user_id): + session = SessionData(user_id=user_id, access_token=f"token_{user_id}") + store.save(session) + + threads = [threading.Thread(target=save_session, args=(f"user{i}",)) for i in range(10)] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # All sessions should be saved + assert len(store.list_users()) == 10 + + +class TestFileSessionStore: + """Test FileSessionStore functionality.""" + + def test_save_and_load_session(self): + """Test saving and loading a session from file.""" + with tempfile.TemporaryDirectory() as tmpdir: + store = FileSessionStore(tmpdir) + session = SessionData(user_id="user123", access_token="token_abc") + + store.save(session) + loaded = store.load("user123") + + assert loaded is not None + assert loaded.user_id == "user123" + assert loaded.access_token == "token_abc" + + def test_persistence_across_instances(self): + """Test sessions persist across store instances.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create session in first store instance + store1 = FileSessionStore(tmpdir) + session = SessionData(user_id="user123", access_token="token_abc") + store1.save(session) + + # Load session in second store instance + store2 = FileSessionStore(tmpdir) + loaded = store2.load("user123") + + assert loaded is not None + assert loaded.user_id == "user123" + assert loaded.access_token == "token_abc" + + def test_file_created(self): + """Test that session file is created.""" + with tempfile.TemporaryDirectory() as tmpdir: + store = FileSessionStore(tmpdir) + session = SessionData(user_id="user123", access_token="token") + + store.save(session) + + # Check file exists + files = list(Path(tmpdir).glob("session_*.json")) + assert len(files) == 1 + + def test_sanitized_filename(self): + """Test that user IDs are sanitized for filenames.""" + with tempfile.TemporaryDirectory() as tmpdir: + store = FileSessionStore(tmpdir) + # User ID with special characters + session = SessionData(user_id="user@email.com", access_token="token") + + store.save(session) + loaded = store.load("user@email.com") + + assert loaded is not None + assert loaded.user_id == "user@email.com" + + def test_load_nonexistent_session(self): + """Test loading nonexistent session returns None.""" + with tempfile.TemporaryDirectory() as tmpdir: + store = FileSessionStore(tmpdir) + loaded = store.load("nonexistent") + + assert loaded is None + + def test_exists(self): + """Test checking if session file exists.""" + with tempfile.TemporaryDirectory() as tmpdir: + store = FileSessionStore(tmpdir) + session = SessionData(user_id="user123", access_token="token") + + assert store.exists("user123") is False + + store.save(session) + + assert store.exists("user123") is True + + def test_delete_session(self): + """Test deleting session file.""" + with tempfile.TemporaryDirectory() as tmpdir: + store = FileSessionStore(tmpdir) + session = SessionData(user_id="user123", access_token="token") + + store.save(session) + deleted = store.delete("user123") + + assert deleted is True + assert store.exists("user123") is False + + def test_list_users(self): + """Test listing users from files.""" + with tempfile.TemporaryDirectory() as tmpdir: + store = FileSessionStore(tmpdir) + + store.save(SessionData(user_id="user1", access_token="token1")) + store.save(SessionData(user_id="user2", access_token="token2")) + store.save(SessionData(user_id="user3", access_token="token3")) + + users = store.list_users() + + assert len(users) == 3 + assert "user1" in users + assert "user2" in users + assert "user3" in users + + def test_clear_all(self): + """Test clearing all session files.""" + with tempfile.TemporaryDirectory() as tmpdir: + store = FileSessionStore(tmpdir) + + store.save(SessionData(user_id="user1", access_token="token1")) + store.save(SessionData(user_id="user2", access_token="token2")) + + count = store.clear_all() + + assert count == 2 + assert len(store.list_users()) == 0 + + def test_cleanup_expired(self): + """Test cleaning up expired session files.""" + with tempfile.TemporaryDirectory() as tmpdir: + store = FileSessionStore(tmpdir) + + # Create sessions: one expired, one valid + store.save( + SessionData( + user_id="expired_user", + access_token="token1", + expires_at=time.time() - 10, + ) + ) + store.save( + SessionData( + user_id="valid_user", + access_token="token2", + expires_at=time.time() + 3600, + ) + ) + + count = store.cleanup_expired() + + assert count == 1 + assert store.exists("expired_user") is False + assert store.exists("valid_user") is True + + def test_directory_creation(self): + """Test that store creates directory if it doesn't exist.""" + with tempfile.TemporaryDirectory() as tmpdir: + session_dir = Path(tmpdir) / "sessions" / "nested" + store = FileSessionStore(session_dir) + + assert session_dir.exists() + assert session_dir.is_dir() + + def test_update_timestamps(self): + """Test that save updates timestamps.""" + with tempfile.TemporaryDirectory() as tmpdir: + store = FileSessionStore(tmpdir) + session = SessionData(user_id="user123", access_token="token") + + original_time = session.updated_at + time.sleep(0.01) # Small delay + + store.save(session) + + loaded = store.load("user123") + assert loaded.updated_at > original_time