From 3d323b1708083477e8a953d5f693519300082a94 Mon Sep 17 00:00:00 2001 From: Hammad Tariq Date: Tue, 6 Jan 2026 21:17:46 +0500 Subject: [PATCH 1/5] mcp server with console plan implementation --- .gitignore | 4 +- IMPLEMENTATION_SUMMARY.md | 344 +++++++++++++++++++++++++++++++ README.md | 107 ++++++++++ VALIDATION_CHECKLIST.md | 292 ++++++++++++++++++++++++++ attach/__init__.py | 7 +- attach/__main__.py | 75 ++++--- attach/cli_claude.py | 141 +++++++++++++ attach/cli_mcp.py | 132 ++++++++++++ attach/gateway.py | 33 ++- audit/__init__.py | 5 + audit/sqlite.py | 315 ++++++++++++++++++++++++++++ auth/oidc.py | 74 ++++--- console/__init__.py | 5 + console/router.py | 131 ++++++++++++ console/static/app.js | 221 ++++++++++++++++++++ console/static/index.html | 90 ++++++++ console/static/style.css | 286 +++++++++++++++++++++++++ examples/agents/coder.py | 17 +- examples/agents/planner.py | 7 +- examples/demo_view_memory.py | 6 +- examples/flask_app/app.py | 19 +- examples/langgraph_demo.py | 41 ++-- main.py | 15 +- mcp/__init__.py | 5 + mcp/config.py | 123 +++++++++++ mcp/proxy.py | 249 ++++++++++++++++++++++ mcp/quota.py | 134 ++++++++++++ mcp/router.py | 90 ++++++++ mem/__init__.py | 7 +- middleware/auth.py | 24 ++- middleware/quota.py | 1 + middleware/session.py | 23 ++- proxy/engine.py | 10 +- pyproject.toml | 2 +- script/dev_login.py | 15 +- tests/test_console_auth.py | 186 +++++++++++++++++ tests/test_jwt_middleware.py | 20 +- tests/test_mcp_optin.py | 144 +++++++++++++ tests/test_mcp_proxy_quota.py | 279 +++++++++++++++++++++++++ tests/test_openmeter_fallback.py | 67 ++++++ tests/test_token_exchange.py | 182 +++++++++------- usage/backends.py | 81 ++++---- usage/factory.py | 14 +- 43 files changed, 3776 insertions(+), 247 deletions(-) create mode 100644 IMPLEMENTATION_SUMMARY.md create mode 100644 VALIDATION_CHECKLIST.md create mode 100644 attach/cli_claude.py create mode 100644 attach/cli_mcp.py create mode 100644 audit/__init__.py create mode 100644 
audit/sqlite.py create mode 100644 console/__init__.py create mode 100644 console/router.py create mode 100644 console/static/app.js create mode 100644 console/static/index.html create mode 100644 console/static/style.css create mode 100644 mcp/__init__.py create mode 100644 mcp/config.py create mode 100644 mcp/proxy.py create mode 100644 mcp/quota.py create mode 100644 mcp/router.py create mode 100644 tests/test_console_auth.py create mode 100644 tests/test_mcp_optin.py create mode 100644 tests/test_mcp_proxy_quota.py create mode 100644 tests/test_openmeter_fallback.py diff --git a/.gitignore b/.gitignore index db27f21..e4a619d 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,6 @@ build/ venv/ .venv/ __pyo3*.so -*.egg-info/ \ No newline at end of file +*.egg-info/ +CLAUDE.md +.claude/ \ No newline at end of file diff --git a/IMPLEMENTATION_SUMMARY.md b/IMPLEMENTATION_SUMMARY.md new file mode 100644 index 0000000..dfc89d9 --- /dev/null +++ b/IMPLEMENTATION_SUMMARY.md @@ -0,0 +1,344 @@ +# Plan 1 MCP Gateway Implementation Summary + +## Overview + +Successfully implemented Plan 1 - an OSS-friendly, local-first MCP Gateway layer inside Attach Gateway that is OPT-IN and backwards compatible with existing OIDC/JWT sidecar functionality. + +## Implementation Status: ✅ COMPLETE + +All 17 planned deliverables have been implemented and tested. + +--- + +## Key Features Delivered + +### 1. MCP Configuration & Lifecycle +- **Module**: `mcp/config.py` +- **Config file**: `~/.attach/mcp.json` +- **CLI commands**: + - `attach-gateway mcp list` + - `attach-gateway mcp add [--header ...]` + - `attach-gateway mcp enable/disable ` + - `attach-gateway mcp remove ` +- **Features**: + - HTTP upstream support (stdio spawning not in MVP) + - Header resolution with `env:VARNAME` syntax + - Never logs resolved secrets + +### 2. 
MCP Reverse Proxy Endpoints +- **Module**: `mcp/router.py`, `mcp/proxy.py` +- **Routes**: + - `POST /mcp/{server}` - Forwards JSON-RPC to configured upstream + - `GET /mcp` - Returns list of servers + enabled state +- **Security**: All `/mcp/*` endpoints require Bearer JWT authentication +- **Error handling**: JSON-RPC error responses for quota/timeout/upstream failures + +### 3. Audit Logging (Local SQLite) +- **Module**: `audit/sqlite.py` +- **Database**: `~/.attach/attach.db` +- **Tables**: + - `mcp_events` - Stores metadata: ts, user, server, method, tool, allowed, latency_ms, error + - `mcp_counters` - Quota usage tracking by date_utc/user/tool +- **Privacy**: No request/response bodies stored, only metadata + +### 4. Quota Enforcement +- **Module**: `mcp/quota.py` +- **Policy file**: `~/.attach/mcp_policy.json` +- **Features**: + - Per-user daily call limits + - Glob pattern matching for tool names (fnmatch) + - Only enforces on `tools/call` JSON-RPC method + - Denied calls return JSON-RPC error (HTTP 200) and are logged + +### 5. Console UI +- **Module**: `console/router.py`, `console/static/` +- **Routes**: + - `GET /console` - Static HTML (unauthenticated, no sensitive data) + - `GET /console/static/*` - Static assets (unauthenticated) + - `GET /console/api/overview` - Statistics (JWT protected) + - `GET /console/api/events` - Event log (JWT protected) + - `GET /console/api/servers` - Server list (JWT protected) +- **Pages**: + - Overview: calls today, denies today, top tools, top users + - Events table: timestamp, user, server, method, tool, allowed, latency, error + - Servers list: enabled servers with URLs +- **Auth model**: Landing page public, API endpoints require JWT + +### 6. Claude Code Installer Helper +- **Module**: `attach/cli_claude.py` +- **Command**: `attach-gateway claude install [--project .] 
[--bearer ] [--write-file]` +- **Modes**: + - Default: Prints `claude mcp add` commands (recommended, no schema brittleness) + - `--write-file`: Writes `.mcp.json` directly (experimental, schema may change) + - `--bearer`: Includes Authorization header if Claude supports it +- **Safety**: Avoids schema brittleness by preferring command generation + +### 7. OpenMeter Non-Fatal Fix +- **Module**: `usage/factory.py` +- **Change**: `USAGE_METERING=openmeter` without `OPENMETER_API_KEY` now logs warning and falls back to `NullUsageBackend` instead of crashing +- **Impact**: Gateway always starts successfully; missing optional config is non-fatal + +--- + +## Modules Added + +### New Top-Level Packages +- `mcp/` - MCP gateway core functionality + - `__init__.py` + - `config.py` - Configuration management + - `proxy.py` - HTTP forwarding logic + - `router.py` - FastAPI routes + - `quota.py` - Quota enforcement + +- `audit/` - Audit logging + - `__init__.py` + - `sqlite.py` - SQLite-based event logging + +- `console/` - Web console UI + - `__init__.py` + - `router.py` - FastAPI routes + - `static/` + - `index.html` - Single-page console app + - `app.js` - Client-side logic with localStorage JWT + - `style.css` - Styling + +### Modified Core Modules +- `attach/gateway.py` - Conditional MCP/console router mounting (opt-in) +- `attach/__main__.py` - Added subcommand support (backward compatible) +- `attach/cli_mcp.py` - MCP server management commands +- `attach/cli_claude.py` - Claude Code integration helper +- `middleware/auth.py` - Added `/console` and `/console/static/*` exclusions +- `middleware/session.py` - Added `/console` and `/console/static/*` exclusions +- `usage/factory.py` - OpenMeter non-fatal fallback +- `pyproject.toml` - Added `mcp`, `audit`, `console` to packages list + +--- + +## Tests Added + +All tests follow existing patterns from `tests/test_jwt_middleware.py`: + +1. 
**`tests/test_mcp_optin.py`** + - Tests MCP routes are 404 without opt-in + - Tests MCP routes available with `ATTACH_ENABLE_MCP=true` + - Tests MCP routes available with `~/.attach/mcp.json` present + +2. **`tests/test_mcp_proxy_quota.py`** + - Tests MCP proxy forwards requests to upstream + - Tests quota enforcement denies after limit exceeded + - Tests unknown server returns 404 with JSON-RPC error + +3. **`tests/test_console_auth.py`** + - Tests `/console` accessible without auth + - Tests `/console/static/*` accessible without auth + - Tests `/console/api/*` requires JWT (401 without, 200 with valid token) + +4. **`tests/test_openmeter_fallback.py`** + - Tests `USAGE_METERING=openmeter` without key doesn't crash + - Tests falls back to `NullUsageBackend` + - Tests `USAGE_METERING=openmeter` with key uses `OpenMeterBackend` + +--- + +## Opt-In Mechanism + +MCP Gateway is enabled if either: +1. Environment variable `ATTACH_ENABLE_MCP=true` is set, OR +2. Config file `~/.attach/mcp.json` exists + +When disabled: +- `/mcp` and `/console` routes are NOT mounted (404) +- No MCP-related imports or initialization +- Zero performance impact on core OIDC sidecar functionality + +--- + +## Security Model + +### Authentication Requirements +| Path | Auth Required | Notes | +|------|---------------|-------| +| `/mcp/*` | ✅ Yes | Bearer JWT required | +| `/console` | ❌ No | Static HTML, no sensitive data | +| `/console/static/*` | ❌ No | CSS/JS/images only | +| `/console/api/*` | ✅ Yes | Bearer JWT required | + +### Privacy & Security Features +- JWT validation uses existing OIDC/DID infrastructure +- Client `Authorization` header NOT forwarded to upstream MCP servers +- Upstream headers configured separately with `env:` support +- Audit logs contain metadata only (no bodies) +- All data local by default (no phone-home) +- User sub logged as first 8 chars only in some places + +--- + +## Backward Compatibility + +### ✅ Guaranteed Safe +- Default behavior unchanged: 
`attach-gateway --port 8080` works exactly as before +- Existing routes unaffected: `/api/chat`, `/a2a/*`, `/mem/*` unchanged +- Auth middleware preserves exact behavior for existing paths +- No new required dependencies +- No breaking changes to environment variables + +### CLI Changes (Backward Compatible) +- `attach-gateway --port 8080` still runs server (default command) +- New subcommands: `attach-gateway mcp ...`, `attach-gateway claude ...` +- Uses `click.Group(invoke_without_command=True)` pattern + +--- + +## README Updates + +Added comprehensive section: "Claude Code + MCP Gateway (Local-First, 2-Minute Setup)" + +Includes: +- Feature overview and benefits +- Quick setup instructions +- CLI command examples +- Claude Code integration guide +- Console UI usage +- How it works diagram +- Security notes + +Location: Between main quickstart and "Use in your project" sections + +--- + +## Definition of Done - Verification + +✅ **Default behavior unchanged** - MCP disabled by default +✅ **MCP explicitly enabled** - Via env var or config file +✅ **Routes mounted conditionally** - `/mcp` and `/console` only when enabled +✅ **JWT protection maintained** - No auth weakening for `/mcp` +✅ **Console auth model secure** - Public landing page, protected API +✅ **Audit logs working** - SQLite metadata storage +✅ **Quota enforcement working** - Glob patterns, daily limits, JSON-RPC errors +✅ **Claude installer helper** - Prints valid commands +✅ **OpenMeter never crashes** - Graceful fallback to NullUsageBackend +✅ **Tests written** - 4 test files covering all features +✅ **README updated** - Stars-magnet quickstart section added +✅ **No new deps** - Uses stdlib + existing FastAPI/httpx/click +✅ **Syntax validated** - All modules compile successfully + +--- + +## Known Limitations (As Designed for MVP) + +1. **HTTP transport only** - No stdio process spawning yet +2. **Single gateway instance** - Quota counters are local (not distributed) +3. 
**UTC day boundary** - Quota resets at midnight UTC regardless of timezone +4. **No response body caching** - Audit logs store metadata only +5. **Basic glob patterns** - Uses fnmatch, not full regex +6. **No rate limiting** - Only daily quotas, no per-minute throttling + +--- + +## Next Steps (Future Work, Not in Plan 1) + +1. **Stdio transport** - Spawn MCP servers as child processes +2. **Distributed quota counters** - Redis/Postgres backend for multi-instance +3. **Response caching** - Optional LLM response memoization +4. **Advanced rate limiting** - Per-minute/hour sliding windows +5. **Webhook notifications** - Alert on quota exceeded +6. **RBAC policies** - Role-based tool access control +7. **Prometheus metrics** - MCP-specific Prometheus exports + +--- + +## File Manifest + +### New Files +``` +mcp/__init__.py +mcp/config.py +mcp/proxy.py +mcp/quota.py +mcp/router.py +audit/__init__.py +audit/sqlite.py +console/__init__.py +console/router.py +console/static/index.html +console/static/app.js +console/static/style.css +attach/cli_mcp.py +attach/cli_claude.py +tests/test_mcp_optin.py +tests/test_mcp_proxy_quota.py +tests/test_console_auth.py +tests/test_openmeter_fallback.py +IMPLEMENTATION_SUMMARY.md +``` + +### Modified Files +``` +attach/gateway.py +attach/__main__.py +middleware/auth.py +middleware/session.py +usage/factory.py +pyproject.toml +README.md +``` + +--- + +## Commands Reference + +### Start Gateway with MCP +```bash +export OIDC_ISSUER=https://your-domain.auth0.com/ +export OIDC_AUD=your-api-identifier +export ATTACH_ENABLE_MCP=true + +attach-gateway --port 8080 +``` + +### Configure MCP Server +```bash +attach-gateway mcp add notion http://localhost:7001/mcp \ + --header "Authorization: env:NOTION_TOKEN" + +attach-gateway mcp enable notion +attach-gateway mcp list +``` + +### Setup Quota Policy +```bash +cat > ~/.attach/mcp_policy.json < ~/.attach/mcp.json < ~/.attach/mcp_policy.json < +""" + +from __future__ import annotations + +import 
json +import os +from pathlib import Path + +import click + +from mcp.config import get_enabled_servers + + +@click.group(name="claude") +def claude_group(): + """Claude Code integration helpers.""" + pass + + +@claude_group.command(name="install") +@click.option( + "--project", default=".", help="Project directory (default: current directory)" +) +@click.option("--bearer", help="Bearer token for Authorization header (optional)") +@click.option( + "--write-file", is_flag=True, help="Write .mcp.json file directly (experimental)" +) +def install_claude(project: str, bearer: str, write_file: bool): + """ + Generate Claude Code MCP configuration. + + This command helps integrate Attach Gateway MCP servers with Claude Code. + + Default mode: Prints 'claude mcp add' commands to run. + --write-file mode: Writes /.mcp.json directly (experimental, schema may be outdated). + """ + enabled = get_enabled_servers() + + if not enabled: + click.echo( + "No enabled MCP servers found. Use 'attach-gateway mcp add' to configure servers first.", + err=True, + ) + return + + # Get gateway URL from environment or default + gateway_url = os.getenv("ATTACH_GATEWAY_URL", "http://localhost:8080") + + if write_file: + _write_mcp_json_file(project, enabled, gateway_url, bearer) + else: + _print_claude_commands(enabled, gateway_url, bearer) + + +def _print_claude_commands(servers: dict, gateway_url: str, bearer: str): + """Print 'claude mcp add' commands for each server.""" + click.echo("Run the following commands to add MCP servers to Claude Code:\n") + + for server_name in servers.keys(): + server_url = f"{gateway_url}/mcp/{server_name}" + + if bearer: + # Check if Claude Code supports headers in HTTP transport + click.echo(f"# Note: If Claude Code HTTP transport supports headers:") + click.echo( + f'claude mcp add --transport http --name "{server_name}" --url "{server_url}" --header "Authorization: Bearer {bearer}"' + ) + click.echo() + click.echo( + f"# If headers are not supported, you'll 
need to configure auth separately or use a different approach." + ) + else: + click.echo( + f'claude mcp add --transport http --name "{server_name}" --url "{server_url}"' + ) + + click.echo() + + if not bearer: + click.echo( + "Tip: Pass --bearer to include Authorization header in commands (if supported by Claude)." + ) + + +def _write_mcp_json_file( + project_dir: str, servers: dict, gateway_url: str, bearer: str +): + """ + Write .mcp.json file for Claude Code (experimental). + + WARNING: This assumes a specific schema that may change. + Prefer using 'claude mcp add' commands instead. + """ + project_path = Path(project_dir).resolve() + mcp_json_path = project_path / ".mcp.json" + + if mcp_json_path.exists(): + if not click.confirm(f"{mcp_json_path} already exists. Overwrite?"): + click.echo("Aborted.") + return + + # Build .mcp.json structure (this is experimental and may not match Claude's schema) + mcp_config = {"mcpServers": {}} + + for server_name in servers.keys(): + server_url = f"{gateway_url}/mcp/{server_name}" + + server_config = { + "transport": "http", + "url": server_url, + } + + if bearer: + # This may not be the correct schema - Claude Code's HTTP transport may not support headers + server_config["headers"] = {"Authorization": f"Bearer {bearer}"} + click.echo( + f"Warning: Added Authorization header, but Claude Code HTTP transport may not support headers.", + err=True, + ) + + mcp_config["mcpServers"][server_name] = server_config + + try: + with open(mcp_json_path, "w", encoding="utf-8") as f: + json.dump(mcp_config, f, indent=2) + + click.echo(f"Wrote MCP configuration to {mcp_json_path}") + click.echo() + click.echo( + "Note: This is experimental. The schema may not match Claude Code's requirements." + ) + click.echo( + "Prefer using 'claude mcp add' commands for guaranteed compatibility." 
+ ) + + except IOError as exc: + click.echo(f"Error writing {mcp_json_path}: {exc}", err=True) diff --git a/attach/cli_mcp.py b/attach/cli_mcp.py new file mode 100644 index 0000000..7d2ae35 --- /dev/null +++ b/attach/cli_mcp.py @@ -0,0 +1,132 @@ +""" +CLI commands for MCP server management. + +attach-gateway mcp list +attach-gateway mcp add +attach-gateway mcp enable +attach-gateway mcp disable +attach-gateway mcp remove +""" + +from __future__ import annotations + +import click + +from mcp.config import ( + get_mcp_config_path, + load_mcp_config, + save_mcp_config, +) + + +@click.group(name="mcp") +def mcp_group(): + """Manage MCP servers.""" + pass + + +@mcp_group.command(name="list") +def list_servers(): + """List all configured MCP servers.""" + config = load_mcp_config() + servers = config.get("servers", {}) + + if not servers: + click.echo("No MCP servers configured.") + click.echo(f"Config file: {get_mcp_config_path()}") + return + + click.echo(f"MCP servers ({get_mcp_config_path()}):\n") + for name, server_config in servers.items(): + enabled = server_config.get("enabled", False) + url = server_config.get("url", "") + status = "✓ enabled" if enabled else "✗ disabled" + click.echo(f" {name:20} {status:12} {url}") + + +@mcp_group.command(name="add") +@click.argument("name") +@click.argument("url") +@click.option( + "--header", multiple=True, help="Header in format 'Key: Value' or 'Key: env:VAR'" +) +def add_server(name: str, url: str, header: tuple[str, ...]): + """Add a new MCP server.""" + config = load_mcp_config() + + if name in config.get("servers", {}): + if not click.confirm(f"Server '{name}' already exists. 
Overwrite?"): + click.echo("Aborted.") + return + + # Parse headers + headers = {} + for h in header: + if ": " not in h: + click.echo(f"Warning: Invalid header format '{h}', skipping", err=True) + continue + key, value = h.split(": ", 1) + headers[key] = value + + config.setdefault("servers", {})[name] = { + "enabled": True, + "url": url, + } + + if headers: + config["servers"][name]["headers"] = headers + + save_mcp_config(config) + click.echo(f"Added MCP server '{name}'") + click.echo(f" URL: {url}") + if headers: + click.echo(f" Headers: {list(headers.keys())}") + + +@mcp_group.command(name="enable") +@click.argument("name") +def enable_server(name: str): + """Enable an MCP server.""" + config = load_mcp_config() + + if name not in config.get("servers", {}): + click.echo(f"Error: Server '{name}' not found", err=True) + return + + config["servers"][name]["enabled"] = True + save_mcp_config(config) + click.echo(f"Enabled MCP server '{name}'") + + +@mcp_group.command(name="disable") +@click.argument("name") +def disable_server(name: str): + """Disable an MCP server.""" + config = load_mcp_config() + + if name not in config.get("servers", {}): + click.echo(f"Error: Server '{name}' not found", err=True) + return + + config["servers"][name]["enabled"] = False + save_mcp_config(config) + click.echo(f"Disabled MCP server '{name}'") + + +@mcp_group.command(name="remove") +@click.argument("name") +def remove_server(name: str): + """Remove an MCP server.""" + config = load_mcp_config() + + if name not in config.get("servers", {}): + click.echo(f"Error: Server '{name}' not found", err=True) + return + + if not click.confirm(f"Remove server '{name}'?"): + click.echo("Aborted.") + return + + del config["servers"][name] + save_mcp_config(config) + click.echo(f"Removed MCP server '{name}'") diff --git a/attach/gateway.py b/attach/gateway.py index 855ca9f..7cee8a1 100644 --- a/attach/gateway.py +++ b/attach/gateway.py @@ -9,12 +9,13 @@ import weaviate from fastapi import 
APIRouter, FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware -from starlette.middleware.base import BaseHTTPMiddleware from pydantic import BaseModel +from starlette.middleware.base import BaseHTTPMiddleware +import logs from a2a.routes import router as a2a_router from auth.oidc import _require_env -import logs + logs_router = logs.router from mem import get_memory_backend from middleware.auth import jwt_auth_mw @@ -27,6 +28,7 @@ # Guard TokenQuotaMiddleware import (matches main.py pattern) try: from middleware.quota import TokenQuotaMiddleware + QUOTA_AVAILABLE = True except ImportError: # optional extra not installed QUOTA_AVAILABLE = False @@ -113,11 +115,17 @@ async def lifespan(app: FastAPI): backend_selector = _select_backend() app.state.usage = get_usage_backend(backend_selector) mount_metrics(app) - + + # Initialize MCP audit DB if MCP is enabled + if getattr(app.state, "mcp_enabled", False): + from audit.sqlite import init_db + + init_db() + yield - + # Shutdown - if hasattr(app.state.usage, 'aclose'): + if hasattr(app.state.usage, "aclose"): await app.state.usage.aclose() @@ -170,7 +178,7 @@ async def auth_config(): allow_headers=["*"], allow_credentials=True, ) - + # Only add quota middleware if available and explicitly configured limit = int_env("MAX_TOKENS_PER_MIN", 60000) if QUOTA_AVAILABLE and limit is not None: @@ -185,6 +193,19 @@ async def auth_config(): app.include_router(logs_router) app.include_router(mem_router) + # Conditionally mount MCP and console routers (opt-in) + from mcp.config import is_mcp_enabled + + mcp_enabled = is_mcp_enabled() + app.state.mcp_enabled = mcp_enabled + + if mcp_enabled: + from console.router import router as console_router + from mcp.router import router as mcp_router + + app.include_router(mcp_router) + app.include_router(console_router) + # Setup memory backend memory_backend = get_memory_backend(config.mem_backend, config) app.state.memory = memory_backend diff --git 
a/audit/__init__.py b/audit/__init__.py new file mode 100644 index 0000000..ac3418f --- /dev/null +++ b/audit/__init__.py @@ -0,0 +1,5 @@ +""" +Audit logging for MCP events +""" + +from __future__ import annotations diff --git a/audit/sqlite.py b/audit/sqlite.py new file mode 100644 index 0000000..8e1ae8d --- /dev/null +++ b/audit/sqlite.py @@ -0,0 +1,315 @@ +""" +SQLite-based audit log for MCP events. + +Database: ~/.attach/attach.db + +Schema: + mcp_events( + id INTEGER PRIMARY KEY, + ts REAL, + user TEXT, + server TEXT, + method TEXT, + tool TEXT, + allowed INTEGER, + latency_ms REAL, + error TEXT + ) + + mcp_counters( + date_utc TEXT, + user TEXT, + tool TEXT, + count INTEGER, + PRIMARY KEY(date_utc, user, tool) + ) +""" + +from __future__ import annotations + +import logging +import sqlite3 +import time +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Optional + +from mcp.config import get_attach_dir + +log = logging.getLogger(__name__) + + +def get_db_path() -> Path: + """Return path to audit database.""" + return get_attach_dir() / "attach.db" + + +def init_db() -> None: + """Initialize audit database schema.""" + db_path = get_db_path() + conn = sqlite3.connect(db_path) + try: + cursor = conn.cursor() + + # MCP events table + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS mcp_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + ts REAL NOT NULL, + user TEXT NOT NULL, + server TEXT NOT NULL, + method TEXT, + tool TEXT, + allowed INTEGER NOT NULL, + latency_ms REAL, + error TEXT + ) + """ + ) + + # Index for queries + cursor.execute( + """ + CREATE INDEX IF NOT EXISTS idx_mcp_events_ts + ON mcp_events(ts DESC) + """ + ) + + cursor.execute( + """ + CREATE INDEX IF NOT EXISTS idx_mcp_events_user + ON mcp_events(user) + """ + ) + + # MCP quota counters table + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS mcp_counters ( + date_utc TEXT NOT NULL, + user TEXT NOT NULL, + tool TEXT NOT NULL, + count INTEGER NOT NULL 
DEFAULT 0, + PRIMARY KEY(date_utc, user, tool) + ) + """ + ) + + conn.commit() + finally: + conn.close() + + +def insert_mcp_event( + user: str, + server: str, + method: Optional[str], + tool: Optional[str], + allowed: bool, + latency_ms: Optional[float] = None, + error: Optional[str] = None, +) -> None: + """Insert an MCP event into audit log.""" + db_path = get_db_path() + try: + conn = sqlite3.connect(db_path) + try: + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO mcp_events (ts, user, server, method, tool, allowed, latency_ms, error) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + time.time(), + user, + server, + method, + tool, + int(allowed), + latency_ms, + error, + ), + ) + conn.commit() + finally: + conn.close() + except sqlite3.Error as exc: + log.error("Failed to insert MCP event: %s", exc) + + +def query_mcp_events( + limit: int = 200, + user: Optional[str] = None, + server: Optional[str] = None, +) -> list[dict[str, Any]]: + """Query MCP events from audit log.""" + db_path = get_db_path() + if not db_path.exists(): + return [] + + try: + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + try: + cursor = conn.cursor() + + query = "SELECT * FROM mcp_events WHERE 1=1" + params: list[Any] = [] + + if user: + query += " AND user = ?" + params.append(user) + if server: + query += " AND server = ?" + params.append(server) + + query += " ORDER BY ts DESC LIMIT ?" + params.append(limit) + + cursor.execute(query, params) + rows = cursor.fetchall() + + return [dict(row) for row in rows] + finally: + conn.close() + except sqlite3.Error as exc: + log.error("Failed to query MCP events: %s", exc) + return [] + + +def overview_stats() -> dict[str, Any]: + """ + Return overview statistics for console dashboard. + + Returns: + { + "calls_today": int, + "denies_today": int, + "top_tools": [(tool, count), ...], + "top_users": [(user, count), ...] 
+ } + """ + db_path = get_db_path() + if not db_path.exists(): + return { + "calls_today": 0, + "denies_today": 0, + "top_tools": [], + "top_users": [], + } + + try: + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + try: + cursor = conn.cursor() + + # Calculate today's midnight UTC timestamp + now_utc = datetime.now(timezone.utc) + today_midnight = datetime( + now_utc.year, now_utc.month, now_utc.day, tzinfo=timezone.utc + ) + today_ts = today_midnight.timestamp() + + # Calls today + cursor.execute( + "SELECT COUNT(*) as cnt FROM mcp_events WHERE ts >= ?", (today_ts,) + ) + calls_today = cursor.fetchone()["cnt"] + + # Denies today + cursor.execute( + "SELECT COUNT(*) as cnt FROM mcp_events WHERE ts >= ? AND allowed = 0", + (today_ts,), + ) + denies_today = cursor.fetchone()["cnt"] + + # Top tools (all time, limited to avoid huge results) + cursor.execute( + """ + SELECT tool, COUNT(*) as cnt + FROM mcp_events + WHERE tool IS NOT NULL + GROUP BY tool + ORDER BY cnt DESC + LIMIT 10 + """ + ) + top_tools = [(row["tool"], row["cnt"]) for row in cursor.fetchall()] + + # Top users (all time) + cursor.execute( + """ + SELECT user, COUNT(*) as cnt + FROM mcp_events + GROUP BY user + ORDER BY cnt DESC + LIMIT 10 + """ + ) + top_users = [(row["user"], row["cnt"]) for row in cursor.fetchall()] + + return { + "calls_today": calls_today, + "denies_today": denies_today, + "top_tools": top_tools, + "top_users": top_users, + } + finally: + conn.close() + except sqlite3.Error as exc: + log.error("Failed to compute overview stats: %s", exc) + return { + "calls_today": 0, + "denies_today": 0, + "top_tools": [], + "top_users": [], + } + + +def get_quota_count(user: str, tool: str, date_utc: str) -> int: + """Get current quota count for a user/tool/day.""" + db_path = get_db_path() + if not db_path.exists(): + return 0 + + try: + conn = sqlite3.connect(db_path) + try: + cursor = conn.cursor() + cursor.execute( + "SELECT count FROM mcp_counters WHERE date_utc = ? 
AND user = ? AND tool = ?", + (date_utc, user, tool), + ) + row = cursor.fetchone() + return row[0] if row else 0 + finally: + conn.close() + except sqlite3.Error as exc: + log.error("Failed to get quota count: %s", exc) + return 0 + + +def increment_quota_count(user: str, tool: str, date_utc: str) -> None: + """Increment quota counter for a user/tool/day.""" + db_path = get_db_path() + try: + conn = sqlite3.connect(db_path) + try: + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO mcp_counters (date_utc, user, tool, count) + VALUES (?, ?, ?, 1) + ON CONFLICT(date_utc, user, tool) + DO UPDATE SET count = count + 1 + """, + (date_utc, user, tool), + ) + conn.commit() + finally: + conn.close() + except sqlite3.Error as exc: + log.error("Failed to increment quota count: %s", exc) diff --git a/auth/oidc.py b/auth/oidc.py index c2d4ec6..d3c1cca 100644 --- a/auth/oidc.py +++ b/auth/oidc.py @@ -6,9 +6,9 @@ from typing import Any import httpx +from dotenv import load_dotenv from jose import jwt -from dotenv import load_dotenv load_dotenv() # --------------------------------------------------------------------------- @@ -61,7 +61,7 @@ def _get_jwks_url(issuer: str) -> str: return f"https://api.descope.com/{project_id}/.well-known/jwks.json" else: if "api.descope.com/v1/apps/" in issuer: - project_id = issuer.split("/")[-1] + project_id = issuer.split("/")[-1] return f"https://api.descope.com/{project_id}/.well-known/jwks.json" else: base_url = issuer.rstrip("/") @@ -104,16 +104,16 @@ async def _exchange_jwt_descope( descope_client_id = _require_env("DESCOPE_CLIENT_ID") descope_client_secret = _require_env("DESCOPE_CLIENT_SECRET") - token_endpoint = f"{descope_base_url}/oauth2/v1/apps/token" + token_endpoint = f"{descope_base_url}/oauth2/v1/apps/token" grant_data = { - "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", "assertion": external_jwt, "client_id": descope_client_id, 
"client_secret": descope_client_secret, "issuer": external_issuer, } - + async with httpx.AsyncClient() as client: response = await client.post( token_endpoint, @@ -177,7 +177,10 @@ def _verify_jwt_direct(token: str, *, leeway: int = 60) -> dict[str, Any]: }, ) -def _verify_jwt_against(token: str, issuer: str, *, audience: str, leeway: int = 60) -> dict[str, Any]: + +def _verify_jwt_against( + token: str, issuer: str, *, audience: str, leeway: int = 60 +) -> dict[str, Any]: header = jwt.get_unverified_header(token) alg = header.get("alg") if alg not in ACCEPTED_ALGS: @@ -196,18 +199,27 @@ def _verify_jwt_against(token: str, issuer: str, *, audience: str, leeway: int = raise ValueError("signing key not found in issuer JWKS") return jwt.decode( - token, key_cfg, algorithms=[alg], - audience=audience, issuer=issuer, - options={"leeway": leeway, "verify_aud": True, "verify_exp": True, "verify_iat": True}, + token, + key_cfg, + algorithms=[alg], + audience=audience, + issuer=issuer, + options={ + "leeway": leeway, + "verify_aud": True, + "verify_exp": True, + "verify_iat": True, + }, ) + async def verify_jwt_with_exchange(token: str, *, leeway: int = 60) -> dict[str, Any]: """ Exchange an external JWT for a Descope token and verify it. First, tries to directly verify the JWT. Then, immediately throws on validation errors without attempting the exchange. - Attempts the exchange. + Attempts the exchange. Returns: Decoded claim set (`dict[str, Any]`) on success. @@ -216,44 +228,52 @@ async def verify_jwt_with_exchange(token: str, *, leeway: int = 60) -> dict[str, ValueError | jose.JWTError on any validation error. ValueError if exchange fails. ValueError if exchange is not applicable (e.g., missing issuer). 
- """ + """ try: return _verify_jwt_direct(token, leeway=leeway) except ValueError as direct_error: # Don't attempt exchange for validation errors like invalid algorithm or missing kid - if any(phrase in str(direct_error) for phrase in [ - "not allowed", - "missing 'kid'", - "invalid token", - "malformed", - "expired" - ]): + if any( + phrase in str(direct_error) + for phrase in [ + "not allowed", + "missing 'kid'", + "invalid token", + "malformed", + "expired", + ] + ): raise direct_error - + try: unverified_claims = jwt.get_unverified_claims(token) external_issuer = unverified_claims.get("iss") - + if not external_issuer: raise ValueError("Cannot extract issuer from token for exchange") - + descope_token = await _exchange_jwt_descope(token, external_issuer) - descope_issuer = f"https://api.descope.com/v1/apps/{_require_env('DESCOPE_PROJECT_ID')}" + descope_issuer = ( + f"https://api.descope.com/v1/apps/{_require_env('DESCOPE_PROJECT_ID')}" + ) audience = os.getenv("DESCOPE_AUD", _get_oidc_audience()) - return _verify_jwt_against(descope_token, issuer=descope_issuer, audience=audience, leeway=leeway) + return _verify_jwt_against( + descope_token, issuer=descope_issuer, audience=audience, leeway=leeway + ) except Exception as exchange_error: - raise ValueError(f"JWT verification failed; direct={direct_error!s}; exchange={exchange_error!s}") + raise ValueError( + f"JWT verification failed; direct={direct_error!s}; exchange={exchange_error!s}" + ) except Exception as other_error: raise other_error - # --------------------------------------------------------------------------- # # Public API # # --------------------------------------------------------------------------- # def verify_jwt(token: str, *, leeway: int = 60) -> dict[str, Any]: """ - Backward-compatible JWT verification + Backward-compatible JWT verification async structure with exchange can be called with verify_jwt_with_exchange """ - return _verify_jwt_direct(token, leeway=leeway) \ No newline at end of 
file + return _verify_jwt_direct(token, leeway=leeway) diff --git a/console/__init__.py b/console/__init__.py new file mode 100644 index 0000000..1634cd2 --- /dev/null +++ b/console/__init__.py @@ -0,0 +1,5 @@ +""" +Local web console for MCP Gateway +""" + +from __future__ import annotations diff --git a/console/router.py b/console/router.py new file mode 100644 index 0000000..b81b062 --- /dev/null +++ b/console/router.py @@ -0,0 +1,131 @@ +""" +FastAPI router for console UI. + +Routes: + GET /console - Serve console HTML (unauthenticated) + GET /console/static/* - Serve static assets (unauthenticated) + GET /console/api/overview - Overview stats (JWT protected) + GET /console/api/events - Recent events (JWT protected) + GET /console/api/servers - Server list (JWT protected) + +Security: + - /console and /console/static/* are unauthenticated (no sensitive data) + - /console/api/* requires JWT Bearer token +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Optional + +from fastapi import APIRouter, Request +from fastapi.responses import FileResponse, JSONResponse + +from audit.sqlite import overview_stats, query_mcp_events +from mcp.config import get_enabled_servers + +log = logging.getLogger(__name__) + +router = APIRouter(prefix="/console", tags=["console"]) + +# Path to console static files +STATIC_DIR = Path(__file__).parent / "static" + + +@router.get("") +async def serve_console(): + """Serve console HTML page (unauthenticated).""" + index_path = STATIC_DIR / "index.html" + if not index_path.exists(): + return JSONResponse(status_code=404, content={"detail": "Console UI not found"}) + return FileResponse(index_path, media_type="text/html") + + +@router.get("/static/{filename}") +async def serve_static(filename: str): + """Serve static assets (unauthenticated).""" + # Security: only allow specific file extensions + allowed_extensions = {".js", ".css", ".html", ".png", ".svg", ".ico"} + file_path = 
STATIC_DIR / filename + + if not file_path.exists() or file_path.suffix not in allowed_extensions: + return JSONResponse(status_code=404, content={"detail": "File not found"}) + + # Prevent directory traversal + if not str(file_path.resolve()).startswith(str(STATIC_DIR.resolve())): + return JSONResponse(status_code=403, content={"detail": "Access denied"}) + + return FileResponse(file_path) + + +@router.get("/api/overview") +async def get_overview(request: Request): + """ + Get overview statistics for dashboard (JWT protected). + + Returns: + { + "calls_today": int, + "denies_today": int, + "top_tools": [[tool, count], ...], + "top_users": [[user, count], ...] + } + """ + # Auth middleware ensures request.state.sub exists + stats = overview_stats() + return stats + + +@router.get("/api/events") +async def get_events( + request: Request, + limit: int = 200, + user: Optional[str] = None, + server: Optional[str] = None, +): + """ + Get recent MCP events (JWT protected). + + Query params: + - limit: Max number of events (default 200) + - user: Filter by user (optional) + - server: Filter by server (optional) + + Returns: + { + "events": [...] + } + """ + # Auth middleware ensures request.state.sub exists + events = query_mcp_events(limit=limit, user=user, server=server) + return {"events": events} + + +@router.get("/api/servers") +async def get_servers(request: Request): + """ + Get list of MCP servers (JWT protected). + + Returns: + { + "servers": { + "server_name": { + "enabled": true, + "url": "http://..." 
+ } + } + } + """ + # Auth middleware ensures request.state.sub exists + enabled = get_enabled_servers() + + # Return sanitized view (no headers) + servers = {} + for name, config in enabled.items(): + servers[name] = { + "enabled": config.get("enabled", False), + "url": config.get("url", ""), + } + + return {"servers": servers} diff --git a/console/static/app.js b/console/static/app.js new file mode 100644 index 0000000..e6eb686 --- /dev/null +++ b/console/static/app.js @@ -0,0 +1,221 @@ +// Attach Gateway MCP Console - Client-side JavaScript + +const AUTH_TOKEN_KEY = 'attach_mcp_token'; + +// Token management +function getToken() { + return localStorage.getItem(AUTH_TOKEN_KEY); +} + +function setToken(token) { + localStorage.setItem(AUTH_TOKEN_KEY, token); +} + +function clearToken() { + localStorage.removeItem(AUTH_TOKEN_KEY); +} + +// API helpers +async function fetchAPI(path, options = {}) { + const token = getToken(); + if (!token) { + throw new Error('No authentication token'); + } + + const headers = { + 'Authorization': `Bearer ${token}`, + 'Content-Type': 'application/json', + ...options.headers, + }; + + const response = await fetch(path, { + ...options, + headers, + }); + + if (response.status === 401) { + clearToken(); + showLoginPrompt(); + throw new Error('Authentication failed'); + } + + if (!response.ok) { + throw new Error(`API error: ${response.status}`); + } + + return response.json(); +} + +// UI state management +function showLoginPrompt() { + document.getElementById('login-prompt').style.display = 'block'; + document.getElementById('dashboard').style.display = 'none'; + document.getElementById('auth-status').textContent = ''; +} + +function showDashboard() { + document.getElementById('login-prompt').style.display = 'none'; + document.getElementById('dashboard').style.display = 'block'; + document.getElementById('auth-status').textContent = 'Authenticated'; + loadOverview(); +} + +function switchView(viewName) { + 
document.querySelectorAll('.view').forEach(v => v.style.display = 'none'); + document.querySelectorAll('.nav-btn').forEach(b => b.classList.remove('active')); + + const viewElement = document.getElementById(`${viewName}-view`); + if (viewElement) { + viewElement.style.display = 'block'; + } + + const navButton = document.querySelector(`.nav-btn[data-view="${viewName}"]`); + if (navButton) { + navButton.classList.add('active'); + } + + if (viewName === 'overview') { + loadOverview(); + } else if (viewName === 'events') { + loadEvents(); + } else if (viewName === 'servers') { + loadServers(); + } +} + +// Data loaders +async function loadOverview() { + try { + const data = await fetchAPI('/console/api/overview'); + + document.getElementById('calls-today').textContent = data.calls_today || 0; + document.getElementById('denies-today').textContent = data.denies_today || 0; + + const topToolsList = document.getElementById('top-tools'); + topToolsList.innerHTML = ''; + (data.top_tools || []).forEach(([tool, count]) => { + const li = document.createElement('li'); + li.textContent = `${tool}: ${count}`; + topToolsList.appendChild(li); + }); + + const topUsersList = document.getElementById('top-users'); + topUsersList.innerHTML = ''; + (data.top_users || []).forEach(([user, count]) => { + const li = document.createElement('li'); + li.textContent = `${user.substring(0, 16)}...: ${count}`; + topUsersList.appendChild(li); + }); + } catch (error) { + console.error('Failed to load overview:', error); + } +} + +async function loadEvents() { + try { + const limit = document.getElementById('events-limit').value; + const data = await fetchAPI(`/console/api/events?limit=${limit}`); + + const tbody = document.getElementById('events-tbody'); + tbody.innerHTML = ''; + + (data.events || []).forEach(event => { + const row = tbody.insertRow(); + row.insertCell().textContent = new Date(event.ts * 1000).toLocaleString(); + row.insertCell().textContent = event.user.substring(0, 16) + '...'; + 
row.insertCell().textContent = event.server || '-'; + row.insertCell().textContent = event.method || '-'; + row.insertCell().textContent = event.tool || '-'; + + const allowedCell = row.insertCell(); + allowedCell.textContent = event.allowed ? 'Yes' : 'No'; + allowedCell.className = event.allowed ? 'allowed-yes' : 'allowed-no'; + + row.insertCell().textContent = event.latency_ms ? event.latency_ms.toFixed(2) : '-'; + row.insertCell().textContent = event.error || '-'; + }); + } catch (error) { + console.error('Failed to load events:', error); + } +} + +async function loadServers() { + try { + const data = await fetchAPI('/console/api/servers'); + + const serversList = document.getElementById('servers-list'); + serversList.innerHTML = ''; + + const servers = data.servers || {}; + if (Object.keys(servers).length === 0) { + serversList.innerHTML = '

No MCP servers configured.

'; + return; + } + + Object.entries(servers).forEach(([name, config]) => { + const card = document.createElement('div'); + card.className = 'server-card'; + + const header = document.createElement('h3'); + header.textContent = name; + card.appendChild(header); + + const status = document.createElement('div'); + status.className = `server-status ${config.enabled ? 'enabled' : 'disabled'}`; + status.textContent = config.enabled ? 'Enabled' : 'Disabled'; + card.appendChild(status); + + const url = document.createElement('div'); + url.className = 'server-url'; + url.textContent = config.url; + card.appendChild(url); + + serversList.appendChild(card); + }); + } catch (error) { + console.error('Failed to load servers:', error); + } +} + +// Event handlers +document.addEventListener('DOMContentLoaded', () => { + // Check if token exists + const token = getToken(); + if (token) { + showDashboard(); + } else { + showLoginPrompt(); + } + + // Save token button + document.getElementById('save-token-btn').addEventListener('click', () => { + const tokenInput = document.getElementById('token-input'); + const token = tokenInput.value.trim(); + if (token) { + setToken(token); + tokenInput.value = ''; + showDashboard(); + } + }); + + // Logout button + document.getElementById('logout-btn').addEventListener('click', () => { + clearToken(); + showLoginPrompt(); + }); + + // Navigation buttons + document.querySelectorAll('.nav-btn').forEach(btn => { + btn.addEventListener('click', (e) => { + const viewName = e.target.dataset.view; + if (viewName) { + switchView(viewName); + } + }); + }); + + // Refresh events button + document.getElementById('refresh-events-btn').addEventListener('click', () => { + loadEvents(); + }); +}); diff --git a/console/static/index.html b/console/static/index.html new file mode 100644 index 0000000..9dc5e5e --- /dev/null +++ b/console/static/index.html @@ -0,0 +1,90 @@ + + + + + + Attach Gateway - MCP Console + + + +
+

Attach Gateway - MCP Console

+
+
+ +
+ + + +
+ + + + diff --git a/console/static/style.css b/console/static/style.css new file mode 100644 index 0000000..e3a4eca --- /dev/null +++ b/console/static/style.css @@ -0,0 +1,286 @@ +/* Attach Gateway MCP Console Styles */ + +* { + margin: 0; + padding: 0; + box-sizing: border-box; +} + +body { + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif; + background: #f5f5f5; + color: #333; + line-height: 1.6; +} + +header { + background: #2c3e50; + color: white; + padding: 1rem 2rem; + display: flex; + justify-content: space-between; + align-items: center; + box-shadow: 0 2px 4px rgba(0,0,0,0.1); +} + +header h1 { + font-size: 1.5rem; + font-weight: 600; +} + +#auth-status { + font-size: 0.9rem; + color: #ecf0f1; +} + +main { + max-width: 1400px; + margin: 2rem auto; + padding: 0 2rem; +} + +/* Login prompt */ +#login-prompt { + background: white; + border-radius: 8px; + padding: 2rem; + max-width: 600px; + margin: 2rem auto; + box-shadow: 0 2px 8px rgba(0,0,0,0.1); +} + +#login-prompt h2 { + margin-bottom: 1rem; + color: #2c3e50; +} + +#login-prompt p { + margin-bottom: 1rem; + color: #666; +} + +#token-input { + width: 100%; + padding: 0.75rem; + border: 1px solid #ddd; + border-radius: 4px; + font-family: monospace; + font-size: 0.9rem; + margin-bottom: 1rem; + resize: vertical; +} + +button { + background: #3498db; + color: white; + border: none; + padding: 0.75rem 1.5rem; + border-radius: 4px; + cursor: pointer; + font-size: 1rem; + transition: background 0.2s; +} + +button:hover { + background: #2980b9; +} + +#logout-btn { + background: #e74c3c; +} + +#logout-btn:hover { + background: #c0392b; +} + +/* Navigation */ +nav { + background: white; + border-radius: 8px; + padding: 1rem; + margin-bottom: 2rem; + display: flex; + gap: 0.5rem; + box-shadow: 0 2px 4px rgba(0,0,0,0.1); +} + +.nav-btn { + background: #ecf0f1; + color: #2c3e50; + padding: 0.5rem 1rem; +} + +.nav-btn:hover { + background: #bdc3c7; +} + 
+.nav-btn.active { + background: #3498db; + color: white; +} + +/* Views */ +.view { + background: white; + border-radius: 8px; + padding: 2rem; + box-shadow: 0 2px 4px rgba(0,0,0,0.1); +} + +.view h2 { + margin-bottom: 1.5rem; + color: #2c3e50; + border-bottom: 2px solid #3498db; + padding-bottom: 0.5rem; +} + +/* Stats grid */ +.stats-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); + gap: 1rem; + margin-bottom: 2rem; +} + +.stat-card { + background: #ecf0f1; + padding: 1.5rem; + border-radius: 8px; + text-align: center; +} + +.stat-card h3 { + font-size: 0.9rem; + color: #666; + margin-bottom: 0.5rem; +} + +.stat-value { + font-size: 2.5rem; + font-weight: bold; + color: #2c3e50; +} + +.stats-row { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); + gap: 2rem; +} + +.stat-section h3 { + font-size: 1.1rem; + margin-bottom: 1rem; + color: #2c3e50; +} + +.stat-section ul { + list-style: none; +} + +.stat-section li { + padding: 0.5rem; + border-bottom: 1px solid #ecf0f1; +} + +/* Events table */ +.filters { + margin-bottom: 1rem; + display: flex; + gap: 1rem; + align-items: center; +} + +.filters label { + display: flex; + align-items: center; + gap: 0.5rem; +} + +.filters input { + padding: 0.5rem; + border: 1px solid #ddd; + border-radius: 4px; + width: 100px; +} + +#events-table-container { + overflow-x: auto; +} + +#events-table { + width: 100%; + border-collapse: collapse; + font-size: 0.9rem; +} + +#events-table th { + background: #34495e; + color: white; + padding: 0.75rem; + text-align: left; + font-weight: 600; +} + +#events-table td { + padding: 0.75rem; + border-bottom: 1px solid #ecf0f1; +} + +#events-table tbody tr:hover { + background: #f8f9fa; +} + +.allowed-yes { + color: #27ae60; + font-weight: 600; +} + +.allowed-no { + color: #e74c3c; + font-weight: 600; +} + +/* Servers */ +#servers-list { + display: grid; + grid-template-columns: repeat(auto-fill, minmax(300px, 1fr)); + gap: 
1rem; +} + +.server-card { + background: #ecf0f1; + padding: 1.5rem; + border-radius: 8px; +} + +.server-card h3 { + font-size: 1.2rem; + color: #2c3e50; + margin-bottom: 0.5rem; +} + +.server-status { + display: inline-block; + padding: 0.25rem 0.75rem; + border-radius: 4px; + font-size: 0.85rem; + font-weight: 600; + margin-bottom: 0.5rem; +} + +.server-status.enabled { + background: #27ae60; + color: white; +} + +.server-status.disabled { + background: #95a5a6; + color: white; +} + +.server-url { + font-family: monospace; + font-size: 0.85rem; + color: #666; + word-break: break-all; +} diff --git a/examples/agents/coder.py b/examples/agents/coder.py index a2833c7..61770ca 100644 --- a/examples/agents/coder.py +++ b/examples/agents/coder.py @@ -1,16 +1,21 @@ # examples/agents/coder.py (drop this in as a full replacement) +import os +import time +import uuid +from typing import Any, Dict, List + +import httpx from fastapi import FastAPI, HTTPException from pydantic import BaseModel -import httpx, os, time, uuid -from typing import List, Dict, Any ENGINE_URL = os.getenv("ENGINE_URL", "http://127.0.0.1:11434") -app = FastAPI(title="Coder Agent") +app = FastAPI(title="Coder Agent") + class ChatRequest(BaseModel): model: str messages: List[Dict[str, Any]] - stream: bool = False + stream: bool = False def _fmt_error(text: str) -> Dict[str, Any]: @@ -22,7 +27,7 @@ def _fmt_error(text: str) -> Dict[str, Any]: "id": f"err-{uuid.uuid4().hex[:8]}", "object": "chat.completion", "created": int(time.time()), - "model": "coder-error", + "model": "coder-error", "choices": [ { "index": 0, @@ -80,4 +85,4 @@ async def chat(req: ChatRequest): if ("choices" not in payload) or not payload["choices"]: return _fmt_error(f"Invalid engine payload: {payload!r}") - return payload \ No newline at end of file + return payload diff --git a/examples/agents/planner.py b/examples/agents/planner.py index 9a8febe..f39f0d9 100644 --- a/examples/agents/planner.py +++ b/examples/agents/planner.py @@ 
-1,8 +1,9 @@ # examples/agents/planner.py +import os + +import httpx from fastapi import FastAPI, HTTPException from pydantic import BaseModel -import httpx -import os app = FastAPI(title="Planner Agent") @@ -37,4 +38,4 @@ async def chat(req: ChatRequest): # surface a clean 502 for gateway diagnostics raise HTTPException(status_code=502, detail=f"Ollama request failed: {e}") - return resp.json() \ No newline at end of file + return resp.json() diff --git a/examples/demo_view_memory.py b/examples/demo_view_memory.py index 6f9f7b2..aefeb20 100644 --- a/examples/demo_view_memory.py +++ b/examples/demo_view_memory.py @@ -24,7 +24,9 @@ # Fetch the last 10 events, newest first result = ( - client.query.get("MemoryEvent", ["timestamp", "event", "user"]) # Fields that actually exist + client.query.get( + "MemoryEvent", ["timestamp", "event", "user"] + ) # Fields that actually exist .with_additional(["id"]) .with_limit(10) .do() @@ -43,4 +45,4 @@ print(json.dumps(o, indent=2)[:600], "...\n") objs = client.data_object.get(class_name="MemoryEvent", limit=1) -print(objs) \ No newline at end of file +print(objs) diff --git a/examples/flask_app/app.py b/examples/flask_app/app.py index 323acfb..cfae5fe 100644 --- a/examples/flask_app/app.py +++ b/examples/flask_app/app.py @@ -1,7 +1,8 @@ -from flask import Flask, request, jsonify -import httpx import os + +import httpx from dotenv import load_dotenv +from flask import Flask, jsonify, request # Load .env file load_dotenv() @@ -9,25 +10,27 @@ app = Flask(__name__) GATEWAY_URL = "http://localhost:8080" -@app.route('/chat', methods=['POST']) + +@app.route("/chat", methods=["POST"]) def chat(): # Get JWT from request - auth_header = request.headers.get('Authorization') + auth_header = request.headers.get("Authorization") if not auth_header: return {"error": "Authorization header required"}, 401 - + # Forward to Attach Gateway try: response = httpx.post( f"{GATEWAY_URL}/api/chat", json=request.json, headers={"Authorization": 
auth_header}, - timeout=30.0 + timeout=30.0, ) response.raise_for_status() return response.json() except httpx.RequestError as e: return {"error": "Gateway unavailable"}, 503 -if __name__ == '__main__': - app.run(debug=True, port=5000) \ No newline at end of file + +if __name__ == "__main__": + app.run(debug=True, port=5000) diff --git a/examples/langgraph_demo.py b/examples/langgraph_demo.py index 9f19bfd..9ceeab9 100644 --- a/examples/langgraph_demo.py +++ b/examples/langgraph_demo.py @@ -1,4 +1,5 @@ from __future__ import annotations + """ LangGraph → Attach-Gateway demo ──────────────────────────── @@ -9,26 +10,32 @@ $ uvicorn main:app --port 8080 # gateway running $ pip install langgraph>=0.0.48 langchain-core>=0.3.0 httpx """ -import asyncio, hashlib, json, os, time +import asyncio +import hashlib +import json +import os +import time from typing import List, TypedDict import httpx from langchain_core.messages import BaseMessage, HumanMessage -from langgraph.graph import StateGraph, END +from langgraph.graph import END, StateGraph # ───────────────────────── Config ────────────────────────── -JWT = os.environ["JWT"] +JWT = os.environ["JWT"] GW_URL = os.getenv("GW_URL", "http://127.0.0.1:8080") -SID = hashlib.sha256((JWT + "demo").encode()).hexdigest()[:16] +SID = hashlib.sha256((JWT + "demo").encode()).hexdigest()[:16] -HEADERS = {"Authorization": f"Bearer {JWT}"} +HEADERS = {"Authorization": f"Bearer {JWT}"} HEADERS_WITH_SESSION = HEADERS | {"X-Attach-Session": SID} + # ─────────────── Helpers: queue + poll Ollama ────────────── def lc_to_openai(msg: BaseMessage) -> dict: role_map = {"human": "user", "ai": "assistant"} return {"role": role_map.get(msg.type, msg.type), "content": msg.content} + async def queue_chat(payload: dict) -> str: async with httpx.AsyncClient() as cli: r = await cli.post( @@ -40,34 +47,39 @@ async def queue_chat(payload: dict) -> str: r.raise_for_status() return r.json()["task_id"] + async def wait_for_result(tid: str, every: float = 
0.5) -> dict: async with httpx.AsyncClient() as cli: while True: r = await cli.get( - f"{GW_URL}/a2a/tasks/status/{tid}", - headers=HEADERS, timeout=10 + f"{GW_URL}/a2a/tasks/status/{tid}", headers=HEADERS, timeout=10 ) j = r.json() if j["state"] in {"done", "error"}: return j await asyncio.sleep(every) + async def ask_ollama(msgs: List[BaseMessage]) -> str: - tid = await queue_chat({ - "model": "tinyllama", - "messages": [lc_to_openai(m) for m in msgs], - "stream": False, - }) + tid = await queue_chat( + { + "model": "tinyllama", + "messages": [lc_to_openai(m) for m in msgs], + "stream": False, + } + ) res = await wait_for_result(tid) if res["state"] == "error": raise RuntimeError(res["result"]) return res["result"]["choices"][0]["message"]["content"] + # ─────────────── LangGraph definition ────────────────────── class State(TypedDict): messages: List[BaseMessage] reply: str | None + async def planner(state: State) -> State: if any(kw in state["messages"][-1].content.lower() for kw in ("code", "python")): state["reply"] = await ask_ollama(state["messages"]) @@ -75,12 +87,14 @@ async def planner(state: State) -> State: state["reply"] = "No code requested." 
return state + sg = StateGraph(State) sg.add_node("planner", planner) sg.set_entry_point("planner") sg.add_edge("planner", END) graph = sg.compile() + # ───────────────────────── Runner ────────────────────────── async def main() -> None: prompt = "Write python to sort a list" @@ -91,5 +105,6 @@ async def main() -> None: print(f"\nAssistant reply (took {time.perf_counter() - t0:.2f}s):\n") print(json.dumps(final["reply"], indent=2)) + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/main.py b/main.py index 1a99d96..83df305 100644 --- a/main.py +++ b/main.py @@ -1,14 +1,15 @@ import os +from contextlib import asynccontextmanager import weaviate from fastapi import APIRouter, FastAPI, HTTPException, Request from fastapi.middleware import Middleware from fastapi.middleware.cors import CORSMiddleware from starlette.middleware.base import BaseHTTPMiddleware -from contextlib import asynccontextmanager -from a2a.routes import router as a2a_router import logs +from a2a.routes import router as a2a_router + logs_router = logs.router from middleware.auth import jwt_auth_mw from middleware.session import session_mw @@ -19,6 +20,7 @@ try: from middleware.quota import TokenQuotaMiddleware + QUOTA_AVAILABLE = True except ImportError: QUOTA_AVAILABLE = False @@ -104,12 +106,13 @@ async def lifespan(app: FastAPI): backend_selector = _select_backend() app.state.usage = get_usage_backend(backend_selector) mount_metrics(app) - + yield - - if hasattr(app.state.usage, 'aclose'): + + if hasattr(app.state.usage, "aclose"): await app.state.usage.aclose() + app = FastAPI(title="attach-gateway", lifespan=lifespan) # Add middleware in correct order (CORS outer-most) @@ -129,6 +132,7 @@ async def lifespan(app: FastAPI): app.add_middleware(BaseHTTPMiddleware, dispatch=jwt_auth_mw) app.add_middleware(BaseHTTPMiddleware, dispatch=session_mw) + @app.get("/auth/config") async def auth_config(): return { @@ -137,6 +141,7 @@ async def 
auth_config(): "audience": os.getenv("OIDC_AUD"), } + app.include_router(a2a_router, prefix="/a2a") app.include_router(logs_router) app.include_router(mem_router) diff --git a/mcp/__init__.py b/mcp/__init__.py new file mode 100644 index 0000000..368eb36 --- /dev/null +++ b/mcp/__init__.py @@ -0,0 +1,5 @@ +""" +MCP Gateway module - opt-in model context protocol gateway +""" + +from __future__ import annotations diff --git a/mcp/config.py b/mcp/config.py new file mode 100644 index 0000000..614c591 --- /dev/null +++ b/mcp/config.py @@ -0,0 +1,123 @@ +""" +MCP server configuration loader. + +Config file: ~/.attach/mcp.json + +Schema: +{ + "version": 1, + "servers": { + "server_name": { + "enabled": true, + "url": "http://localhost:7001/mcp", + "headers": { + "Authorization": "env:NOTION_TOKEN", + "X-Custom": "literal-value" + } + } + } +} +""" + +from __future__ import annotations + +import json +import logging +import os +from pathlib import Path +from typing import Any, Optional + +log = logging.getLogger(__name__) + + +def get_attach_dir() -> Path: + """Return ~/.attach directory, creating if needed.""" + attach_dir = Path.home() / ".attach" + attach_dir.mkdir(exist_ok=True) + return attach_dir + + +def get_mcp_config_path() -> Path: + """Return path to MCP config file.""" + return get_attach_dir() / "mcp.json" + + +def load_mcp_config() -> dict[str, Any]: + """ + Load MCP configuration from ~/.attach/mcp.json. + Returns empty config structure if file doesn't exist. 
+ """ + path = get_mcp_config_path() + if not path.exists(): + return {"version": 1, "servers": {}} + + try: + with open(path, "r", encoding="utf-8") as f: + config = json.load(f) + return config + except (json.JSONDecodeError, IOError) as exc: + log.warning("Failed to load MCP config from %s: %s", path, exc) + return {"version": 1, "servers": {}} + + +def save_mcp_config(config: dict[str, Any]) -> None: + """Save MCP configuration to ~/.attach/mcp.json.""" + path = get_mcp_config_path() + try: + with open(path, "w", encoding="utf-8") as f: + json.dump(config, f, indent=2) + except IOError as exc: + log.error("Failed to save MCP config to %s: %s", path, exc) + raise + + +def get_enabled_servers() -> dict[str, dict[str, Any]]: + """Return dict of enabled MCP servers with their configs.""" + config = load_mcp_config() + servers = config.get("servers", {}) + return { + name: server_config + for name, server_config in servers.items() + if server_config.get("enabled", False) + } + + +def resolve_header_value(value: str) -> str: + """ + Resolve header value, supporting env: prefix for environment variables. + + Examples: + "Bearer token123" -> "Bearer token123" + "env:NOTION_TOKEN" -> os.getenv("NOTION_TOKEN", "") + """ + if value.startswith("env:"): + var_name = value[4:] + resolved = os.getenv(var_name, "") + if not resolved: + log.warning("Environment variable %s not set for MCP header", var_name) + return resolved + return value + + +def get_server_headers(server_config: dict[str, Any]) -> dict[str, str]: + """ + Extract and resolve headers for a server configuration. + Never logs resolved secrets. + """ + headers_config = server_config.get("headers", {}) + resolved = {} + for key, value in headers_config.items(): + resolved_value = resolve_header_value(value) + if resolved_value: + resolved[key] = resolved_value + return resolved + + +def is_mcp_enabled() -> bool: + """ + Check if MCP gateway is enabled. 
+ Returns True if ATTACH_ENABLE_MCP=true OR if ~/.attach/mcp.json exists. + """ + if os.getenv("ATTACH_ENABLE_MCP", "").lower() == "true": + return True + return get_mcp_config_path().exists() diff --git a/mcp/proxy.py b/mcp/proxy.py new file mode 100644 index 0000000..8c47c74 --- /dev/null +++ b/mcp/proxy.py @@ -0,0 +1,249 @@ +""" +MCP reverse proxy logic. + +Forwards JSON-RPC requests to configured MCP servers. +""" + +from __future__ import annotations + +import logging +import os +import time +from typing import Any, Optional + +import httpx + +from audit.sqlite import insert_mcp_event +from mcp.config import get_enabled_servers, get_server_headers +from mcp.quota import check_quota, record_tool_call + +log = logging.getLogger(__name__) + +# Default timeout for MCP upstream requests (seconds) +DEFAULT_TIMEOUT = 30.0 + + +def get_mcp_timeout() -> float: + """Get MCP proxy timeout from environment.""" + timeout_str = os.getenv("ATTACH_MCP_TIMEOUT", str(DEFAULT_TIMEOUT)) + try: + return float(timeout_str) + except ValueError: + log.warning( + "Invalid ATTACH_MCP_TIMEOUT=%s, using default %s", + timeout_str, + DEFAULT_TIMEOUT, + ) + return DEFAULT_TIMEOUT + + +def extract_tool_name(body: dict[str, Any]) -> Optional[str]: + """ + Extract tool name from JSON-RPC tools/call request. + + Expected structure: + { + "jsonrpc": "2.0", + "method": "tools/call", + "params": { + "name": "github.create_issue", + ... + }, + "id": 1 + } + """ + if body.get("method") != "tools/call": + return None + + params = body.get("params", {}) + if isinstance(params, dict): + return params.get("name") + + return None + + +async def proxy_mcp_request( + server_name: str, + body: dict[str, Any], + user: str, +) -> tuple[int, dict[str, Any]]: + """ + Proxy a JSON-RPC request to the configured MCP server. 
+ + Args: + server_name: Name of the MCP server + body: JSON-RPC request body + user: User identifier (from JWT sub claim) + + Returns: + (status_code, response_body) + + Notes: + - Enforces quota if method is "tools/call" + - Logs all requests to audit log + - Does NOT forward client Authorization header + - Uses configured headers from mcp.json + """ + start_time = time.time() + + # Check if server exists and is enabled + enabled_servers = get_enabled_servers() + if server_name not in enabled_servers: + insert_mcp_event( + user=user, + server=server_name, + method=None, + tool=None, + allowed=False, + error="server not found or disabled", + ) + return ( + 404, + { + "jsonrpc": "2.0", + "id": body.get("id"), + "error": { + "code": -32001, + "message": f"MCP server '{server_name}' not found or disabled", + }, + }, + ) + + server_config = enabled_servers[server_name] + upstream_url = server_config.get("url") + if not upstream_url: + insert_mcp_event( + user=user, + server=server_name, + method=None, + tool=None, + allowed=False, + error="server URL not configured", + ) + return ( + 500, + { + "jsonrpc": "2.0", + "id": body.get("id"), + "error": { + "code": -32002, + "message": f"MCP server '{server_name}' has no URL configured", + }, + }, + ) + + # Extract method and tool name for quota enforcement + method = body.get("method") + tool = extract_tool_name(body) if method == "tools/call" else None + + # Quota check for tools/call + if method == "tools/call" and tool: + allowed, error_msg = check_quota(user, tool) + if not allowed: + latency_ms = (time.time() - start_time) * 1000 + insert_mcp_event( + user=user, + server=server_name, + method=method, + tool=tool, + allowed=False, + latency_ms=latency_ms, + error=error_msg, + ) + return ( + 200, + { + "jsonrpc": "2.0", + "id": body.get("id"), + "error": { + "code": -32029, + "message": "tool quota exceeded", + "data": { + "tool": tool, + "error": error_msg, + }, + }, + }, + ) + + # Forward request to upstream + headers 
= get_server_headers(server_config) + headers["Content-Type"] = "application/json" + + timeout = get_mcp_timeout() + + try: + async with httpx.AsyncClient(timeout=timeout) as client: + response = await client.post(upstream_url, json=body, headers=headers) + + latency_ms = (time.time() - start_time) * 1000 + + # Record successful call + if method == "tools/call" and tool: + record_tool_call(user, tool) + + insert_mcp_event( + user=user, + server=server_name, + method=method, + tool=tool, + allowed=True, + latency_ms=latency_ms, + error=None if response.is_success else f"HTTP {response.status_code}", + ) + + # Return upstream response + try: + response_json = response.json() + except Exception: + response_json = {"error": "upstream returned non-JSON response"} + + return (response.status_code, response_json) + + except httpx.TimeoutException: + latency_ms = (time.time() - start_time) * 1000 + insert_mcp_event( + user=user, + server=server_name, + method=method, + tool=tool, + allowed=False, + latency_ms=latency_ms, + error="upstream timeout", + ) + return ( + 504, + { + "jsonrpc": "2.0", + "id": body.get("id"), + "error": { + "code": -32003, + "message": f"MCP server '{server_name}' timeout after {timeout}s", + }, + }, + ) + + except httpx.RequestError as exc: + latency_ms = (time.time() - start_time) * 1000 + error_msg = f"upstream error: {exc}" + insert_mcp_event( + user=user, + server=server_name, + method=method, + tool=tool, + allowed=False, + latency_ms=latency_ms, + error=error_msg, + ) + return ( + 502, + { + "jsonrpc": "2.0", + "id": body.get("id"), + "error": { + "code": -32004, + "message": f"Failed to reach MCP server '{server_name}'", + "data": str(exc), + }, + }, + ) diff --git a/mcp/quota.py b/mcp/quota.py new file mode 100644 index 0000000..54aed0c --- /dev/null +++ b/mcp/quota.py @@ -0,0 +1,134 @@ +""" +MCP tool call quota enforcement. 
+ +Policy file: ~/.attach/mcp_policy.json + +Schema: +{ + "version": 1, + "enabled": true, + "per_user_daily_tool_calls": { + "github.*": 200, + "notion.*": 100, + "*": 1000 + } +} + +Quota enforcement: +- Only applies when method == "tools/call" +- Uses fnmatch glob patterns for tool names +- Day boundary: UTC midnight +- Counters stored in SQLite (audit/sqlite.py) +""" + +from __future__ import annotations + +import fnmatch +import json +import logging +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Optional + +from audit.sqlite import get_quota_count, increment_quota_count +from mcp.config import get_attach_dir + +log = logging.getLogger(__name__) + + +def get_policy_path() -> Path: + """Return path to MCP policy file.""" + return get_attach_dir() / "mcp_policy.json" + + +def load_policy() -> dict[str, Any]: + """Load MCP quota policy from ~/.attach/mcp_policy.json.""" + path = get_policy_path() + if not path.exists(): + return {"version": 1, "enabled": False, "per_user_daily_tool_calls": {}} + + try: + with open(path, "r", encoding="utf-8") as f: + policy = json.load(f) + return policy + except (json.JSONDecodeError, IOError) as exc: + log.warning("Failed to load MCP policy from %s: %s", path, exc) + return {"version": 1, "enabled": False, "per_user_daily_tool_calls": {}} + + +def is_quota_enabled() -> bool: + """Check if quota enforcement is enabled.""" + policy = load_policy() + return policy.get("enabled", False) + + +def get_tool_limit(tool: str) -> Optional[int]: + """ + Get daily tool call limit for a given tool name using glob matching. + Returns None if no limit applies. + + Matching priority: + 1. Exact match + 2. First glob pattern that matches + 3. 
Wildcard "*" if present + """ + policy = load_policy() + limits = policy.get("per_user_daily_tool_calls", {}) + + # Exact match first + if tool in limits: + return limits[tool] + + # Try glob patterns + for pattern, limit in limits.items(): + if "*" in pattern or "?" in pattern or "[" in pattern: + if fnmatch.fnmatch(tool, pattern): + return limit + + # Fallback to wildcard + if "*" in limits: + return limits["*"] + + return None + + +def get_current_date_utc() -> str: + """Return current date in UTC as YYYY-MM-DD string.""" + now_utc = datetime.now(timezone.utc) + return now_utc.strftime("%Y-%m-%d") + + +def check_quota(user: str, tool: str) -> tuple[bool, Optional[str]]: + """ + Check if user is within quota for tool. + + Returns: + (allowed: bool, error_msg: Optional[str]) + - (True, None) if allowed + - (False, "quota exceeded...") if denied + """ + if not is_quota_enabled(): + return (True, None) + + limit = get_tool_limit(tool) + if limit is None: + # No limit configured for this tool + return (True, None) + + date_utc = get_current_date_utc() + current_count = get_quota_count(user, tool, date_utc) + + if current_count >= limit: + error_msg = f"tool quota exceeded: {tool} limit={limit} used={current_count}" + return (False, error_msg) + + return (True, None) + + +def record_tool_call(user: str, tool: str) -> None: + """Record a tool call in quota counters.""" + if not is_quota_enabled(): + return + + date_utc = get_current_date_utc() + increment_quota_count(user, tool, date_utc) diff --git a/mcp/router.py b/mcp/router.py new file mode 100644 index 0000000..e25cedb --- /dev/null +++ b/mcp/router.py @@ -0,0 +1,90 @@ +""" +FastAPI router for MCP gateway endpoints. 
+ +Routes: + GET /mcp - List enabled MCP servers + POST /mcp/{name} - Proxy JSON-RPC request to named server +""" + +from __future__ import annotations + +import logging +from typing import Any + +from fastapi import APIRouter, Request +from fastapi.responses import JSONResponse + +from mcp.config import get_enabled_servers +from mcp.proxy import proxy_mcp_request + +log = logging.getLogger(__name__) + +router = APIRouter(prefix="/mcp", tags=["mcp"]) + + +@router.get("") +async def list_mcp_servers(): + """ + List all enabled MCP servers. + + Returns: + { + "servers": { + "server_name": { + "enabled": true, + "url": "http://..." + } + } + } + """ + enabled = get_enabled_servers() + + # Return sanitized view (no headers) + servers = {} + for name, config in enabled.items(): + servers[name] = { + "enabled": config.get("enabled", False), + "url": config.get("url", ""), + } + + return {"servers": servers} + + +@router.post("/{server}") +async def proxy_to_mcp_server(server: str, request: Request): + """ + Proxy a JSON-RPC request to the named MCP server. 
+ + Args: + server: Server name from URL path + request: FastAPI request containing JSON-RPC body + + Returns: + JSON-RPC response from upstream server + + Notes: + - Requires Bearer token authentication (JWT) + - Enforces quota for tools/call method + - Logs all requests to audit log + """ + # Extract user from request.state (set by auth middleware) + user = getattr(request.state, "sub", "unknown") + + # Parse JSON body + try: + body: dict[str, Any] = await request.json() + except Exception as exc: + log.warning("Failed to parse JSON-RPC body: %s", exc) + return JSONResponse( + status_code=400, + content={ + "jsonrpc": "2.0", + "id": None, + "error": {"code": -32700, "message": "Parse error: invalid JSON"}, + }, + ) + + # Proxy the request + status_code, response_body = await proxy_mcp_request(server, body, user) + + return JSONResponse(status_code=status_code, content=response_body) diff --git a/mem/__init__.py b/mem/__init__.py index f762f4f..c02c4ac 100644 --- a/mem/__init__.py +++ b/mem/__init__.py @@ -1,10 +1,11 @@ from __future__ import annotations -# mem/__init__.py - import asyncio import os -from typing import Protocol, Optional, Union +from typing import Optional, Protocol, Union + +# mem/__init__.py + class MemoryBackend(Protocol): diff --git a/middleware/auth.py b/middleware/auth.py index 81c7372..4e41e2b 100644 --- a/middleware/auth.py +++ b/middleware/auth.py @@ -4,6 +4,7 @@ This file *must* live inside the project's `middleware/` package so that `from middleware.auth import jwt_auth_mw` works. 
""" + from __future__ import annotations import os @@ -11,7 +12,10 @@ from fastapi import HTTPException, Request from fastapi.responses import JSONResponse -from auth.oidc import verify_jwt, verify_jwt_with_exchange # your existing verifier (RS256 / ES256 only) +from auth.oidc import ( # your existing verifier (RS256 / ES256 only) + verify_jwt, + verify_jwt_with_exchange, +) _CLOCK_SKEW = 60 # seconds @@ -23,6 +27,11 @@ "/openapi.json", } +# Path prefixes that don't require authentication (for console static assets) +EXCLUDED_PATH_PREFIXES = [ + "/console/static/", +] + async def jwt_auth_mw(request: Request, call_next): """ @@ -35,11 +44,20 @@ async def jwt_auth_mw(request: Request, call_next): # Skip authentication for OPTIONS requests (CORS preflight) if request.method == "OPTIONS": return await call_next(request) - + # Skip authentication for excluded paths if request.url.path in EXCLUDED_PATHS: return await call_next(request) + # Skip authentication for console (unauthenticated landing page) + if request.url.path == "/console": + return await call_next(request) + + # Skip authentication for excluded path prefixes (console static assets) + for prefix in EXCLUDED_PATH_PREFIXES: + if request.url.path.startswith(prefix): + return await call_next(request) + auth_header = request.headers.get("authorization", "") if not auth_header.startswith("Bearer "): return JSONResponse(status_code=401, content={"detail": "Missing Bearer token"}) @@ -49,7 +67,7 @@ async def jwt_auth_mw(request: Request, call_next): try: # Use sync version unless Descope exchange is explicitly enabled if os.getenv("ENABLE_DESCOPE_EXCHANGE", "false").lower() == "true": - claims = await verify_jwt_with_exchange(token, leeway=_CLOCK_SKEW) + claims = await verify_jwt_with_exchange(token, leeway=_CLOCK_SKEW) else: claims = verify_jwt(token, leeway=_CLOCK_SKEW) # original sync version except Exception as exc: diff --git a/middleware/quota.py b/middleware/quota.py index 15af934..7c4b7fe 100644 --- 
a/middleware/quota.py +++ b/middleware/quota.py @@ -165,6 +165,7 @@ def _is_textual(mime: str) -> bool: # Token-count helpers # --------------------------------------------------------------------------- + def _encoder_for_model(model: str): """Return a tiktoken encoder, falling back to byte count.""" if tiktoken is None: # fallback: 1 token ≈ 4 bytes diff --git a/middleware/session.py b/middleware/session.py index b6ac8c0..b70a73c 100644 --- a/middleware/session.py +++ b/middleware/session.py @@ -18,6 +18,12 @@ "/openapi.json", } +# Path prefixes that don't require authentication +EXCLUDED_PATH_PREFIXES = [ + "/console/static/", +] + + def _session_id(sub: str, user_agent: str) -> str: return hashlib.sha256(f"{sub}:{user_agent}".encode()).hexdigest() @@ -26,14 +32,23 @@ async def session_mw(request: Request, call_next): # Skip session middleware for excluded paths if request.url.path in EXCLUDED_PATHS: return await call_next(request) - + + # Skip session middleware for console + if request.url.path == "/console": + return await call_next(request) + + # Skip session middleware for excluded path prefixes + for prefix in EXCLUDED_PATH_PREFIXES: + if request.url.path.startswith(prefix): + return await call_next(request) + # Let the auth middleware handle authentication first response: Response = await call_next(request) - + # Only set session ID if sub is available (after auth middleware runs) if hasattr(request.state, "sub"): sid = _session_id(request.state.sub, request.headers.get("user-agent", "")) request.state.sid = sid # expose to downstream handlers response.headers["X-Attach-Session"] = sid[:16] # expose *truncated* sid - response.headers["X-Attach-User"] = request.state.sub[:32] - return response \ No newline at end of file + response.headers["X-Attach-User"] = request.state.sub[:32] + return response diff --git a/proxy/engine.py b/proxy/engine.py index ab02c21..47f8e6c 100644 --- a/proxy/engine.py +++ b/proxy/engine.py @@ -59,16 +59,20 @@ async def 
proxy_to_engine(request: Request): try: return StreamingResponse( - _upstream_stream(request.method, upstream_url, headers=headers, payload=body), + _upstream_stream( + request.method, upstream_url, headers=headers, payload=body + ), media_type="application/json", ) except httpx.HTTPStatusError as exc: # Bubble the upstream status so callers can act accordingly - raise HTTPException(status_code=exc.response.status_code, detail=exc.response.text) + raise HTTPException( + status_code=exc.response.status_code, detail=exc.response.text + ) except Exception as exc: # Log & hide internals from the client # (LOGGER omitted for brevity – add one if you like) raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, detail="Upstream chat engine error", - ) from exc \ No newline at end of file + ) from exc diff --git a/pyproject.toml b/pyproject.toml index d27ef50..feb6a95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ attach-gateway = "attach.__main__:main" # Include your existing modules with cleaner targeting [tool.hatch.build.targets.wheel] -packages = ["attach", "auth", "middleware", "mem", "proxy", "a2a", "attach_pydid", "usage", "utils", "logs"] +packages = ["attach", "auth", "middleware", "mem", "proxy", "a2a", "attach_pydid", "usage", "utils", "logs", "mcp", "audit", "console"] # Dynamic version from attach/__init__.py diff --git a/script/dev_login.py b/script/dev_login.py index ded7b67..1393323 100644 --- a/script/dev_login.py +++ b/script/dev_login.py @@ -1,16 +1,21 @@ -import os, httpx, sys, json +import json +import os +import sys + +import httpx resp = httpx.post( f"https://{os.getenv('AUTH0_DOMAIN')}/oauth/token", json={ - "client_id": os.getenv("AUTH0_CLIENT"), + "client_id": os.getenv("AUTH0_CLIENT"), "client_secret": os.getenv("AUTH0_SECRET"), # add to .env - "audience": os.getenv("OIDC_AUD"), - "grant_type": "client_credentials", + "audience": os.getenv("OIDC_AUD"), + "grant_type": "client_credentials", }, ).json() if "access_token" 
not in resp: - sys.stderr.write(json.dumps(resp, indent=2) + "\n"); sys.exit(1) + sys.stderr.write(json.dumps(resp, indent=2) + "\n") + sys.exit(1) print(resp["access_token"]) diff --git a/tests/test_console_auth.py b/tests/test_console_auth.py new file mode 100644 index 0000000..0694eab --- /dev/null +++ b/tests/test_console_auth.py @@ -0,0 +1,186 @@ +""" +Test console auth model. + +- /console should be accessible without auth (static HTML) +- /console/static/* should be accessible without auth +- /console/api/* should require JWT Bearer token +""" + +import json +import os + +import pytest +from httpx import ASGITransport, AsyncClient + +os.environ["OIDC_ISSUER"] = "https://test.auth0.com/" +os.environ["OIDC_AUD"] = "test-api" +os.environ["ATTACH_ENABLE_MCP"] = "true" + +from jose import JWTError + +import auth.oidc +import middleware.auth +from attach.gateway import create_app + +DUMMY_GOOD_TOKEN = ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ0ZXN0LXVzZXIifQ.s3cr3t" +) + + +@pytest.fixture(autouse=True) +def stub_verify_jwt(monkeypatch): + """Stub JWT verification.""" + + def fake_verify_sync(token: str, *, leeway: int = 60): + if token == DUMMY_GOOD_TOKEN: + return {"sub": "test-user"} + raise JWTError("invalid token") + + async def fake_verify_async(token: str, *, leeway: int = 60): + if token == DUMMY_GOOD_TOKEN: + return {"sub": "test-user"} + raise JWTError("invalid token") + + monkeypatch.setattr(auth.oidc, "verify_jwt", fake_verify_sync) + monkeypatch.setattr(auth.oidc, "verify_jwt_with_exchange", fake_verify_async) + monkeypatch.setattr(middleware.auth, "verify_jwt", fake_verify_sync) + monkeypatch.setattr(middleware.auth, "verify_jwt_with_exchange", fake_verify_async) + + +@pytest.fixture +def temp_attach_dir(monkeypatch, tmp_path): + """Override ~/.attach directory with temp dir.""" + attach_dir = tmp_path / "attach" + attach_dir.mkdir() + + import audit.sqlite + import mcp.config + + monkeypatch.setattr(mcp.config, "get_attach_dir", lambda: 
attach_dir) + monkeypatch.setattr(audit.sqlite, "get_attach_dir", lambda: attach_dir) + + # Create mcp.json to enable MCP + mcp_config = {"version": 1, "servers": {}} + (attach_dir / "mcp.json").write_text(json.dumps(mcp_config)) + + # Initialize audit DB + from audit.sqlite import init_db + + init_db() + + return attach_dir + + +@pytest.mark.asyncio +async def test_console_landing_page_no_auth(temp_attach_dir): + """Test that /console is accessible without authentication.""" + app = create_app() + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + # Access /console without auth + response = await client.get("/console") + + # Should return 200 (HTML page loads) + assert response.status_code == 200 + # Should contain HTML content + assert ( + b"html" in response.content.lower() + or b" None: self.api_key = api_key self.base_url = os.getenv("OPENMETER_URL", "https://openmeter.cloud") - + # Use httpx instead of buggy OpenMeter SDK try: import httpx - self.client = httpx.AsyncClient( - timeout=30.0 - ) + + self.client = httpx.AsyncClient(timeout=30.0) except ImportError as exc: raise ImportError("httpx is required for OpenMeter") from exc async def aclose(self) -> None: """Close the underlying HTTP client.""" - if hasattr(self.client, 'aclose'): + if hasattr(self.client, "aclose"): await self.client.aclose() async def record(self, **evt) -> None: @@ -116,45 +115,53 @@ async def record(self, **evt) -> None: except ImportError as exc: return - base_time = datetime.now(timezone.utc).isoformat(timespec="milliseconds").replace("+00:00", "Z") + base_time = ( + datetime.now(timezone.utc) + .isoformat(timespec="milliseconds") + .replace("+00:00", "Z") + ) user = evt.get("user") model = evt.get("model") - + tokens_in = int(evt.get("tokens_in", 0) or 0) tokens_out = int(evt.get("tokens_out", 0) or 0) # Send separate events for input and output tokens events_to_send = [] - + if tokens_in > 0: - events_to_send.append({ - 
"specversion": "1.0", - "type": "prompt", # ← Changed from "tokens" to "prompt" - "id": str(uuid4()), - "time": base_time, - "source": "attach-gateway", - "subject": user, - "data": { - "tokens": tokens_in, - "model": model, - "type": "input" # ← This stays the same + events_to_send.append( + { + "specversion": "1.0", + "type": "prompt", # ← Changed from "tokens" to "prompt" + "id": str(uuid4()), + "time": base_time, + "source": "attach-gateway", + "subject": user, + "data": { + "tokens": tokens_in, + "model": model, + "type": "input", # ← This stays the same + }, } - }) - + ) + if tokens_out > 0: - events_to_send.append({ - "specversion": "1.0", - "type": "prompt", - "id": str(uuid4()), - "time": base_time, - "source": "attach-gateway", - "subject": user, - "data": { - "tokens": tokens_out, # ← Single tokens field - "model": model, - "type": "output" # ← Add type field + events_to_send.append( + { + "specversion": "1.0", + "type": "prompt", + "id": str(uuid4()), + "time": base_time, + "source": "attach-gateway", + "subject": user, + "data": { + "tokens": tokens_out, # ← Single tokens field + "model": model, + "type": "output", # ← Add type field + }, } - }) + ) # Send each event for event in events_to_send: @@ -164,12 +171,12 @@ async def record(self, **evt) -> None: json=event, headers={ "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/cloudevents+json" - } + "Content-Type": "application/cloudevents+json", + }, ) - + if response.status_code not in [200, 201, 202, 204]: logger.warning(f"OpenMeter error: {response.status_code}") - + except Exception as exc: logger.warning("OpenMeter request failed: %s", exc) diff --git a/usage/factory.py b/usage/factory.py index 61ed842..81663e2 100644 --- a/usage/factory.py +++ b/usage/factory.py @@ -2,9 +2,9 @@ """Factory for usage backends.""" +import logging import os import warnings -import logging from .backends import ( AbstractUsageBackend, @@ -41,17 +41,19 @@ def get_usage_backend(kind: str) -> 
AbstractUsageBackend: "Prometheus metering unavailable: %s – " "falling back to NullUsageBackend. " "Install with: pip install 'attach-dev[usage]'", - exc + exc, ) return NullUsageBackend() if kind == "openmeter": - # fail-fast on bad config + # Graceful fallback if API key is missing if not os.getenv("OPENMETER_API_KEY"): - raise RuntimeError( - "USAGE_METERING=openmeter requires OPENMETER_API_KEY. " - "Set the variable or change USAGE_METERING=null to disable." + log.warning( + "USAGE_METERING=openmeter but OPENMETER_API_KEY not set. " + "Falling back to NullUsageBackend. " + "Set OPENMETER_API_KEY to enable OpenMeter metering." ) + return NullUsageBackend() return OpenMeterBackend() # exceptions inside bubble up return NullUsageBackend() From 306fe3899ab884e1a348a7e428763f0685d67349 Mon Sep 17 00:00:00 2001 From: Hammad Tariq Date: Tue, 6 Jan 2026 21:44:49 +0500 Subject: [PATCH 2/5] fixed bugs --- README.md | 8 +++-- attach/cli_claude.py | 63 +++++++++++++++-------------------- audit/sqlite.py | 53 +++++++++++++++++++++++++---- mcp/config.py | 14 ++++++-- mcp/quota.py | 39 ++++++++++++++++------ middleware/auth.py | 3 +- middleware/session.py | 3 +- tests/test_mcp_optin.py | 6 +++- tests/test_mcp_proxy_quota.py | 3 ++ 9 files changed, 130 insertions(+), 62 deletions(-) diff --git a/README.md b/README.md index 4049418..8a59b3a 100644 --- a/README.md +++ b/README.md @@ -160,8 +160,12 @@ attach-gateway --port 8080 # Generate Claude Code configuration commands attach-gateway claude install --project . 
-# Or manually add servers: -claude mcp add --transport http --name "notion" --url "http://localhost:8080/mcp/notion" +# Or manually add servers (using positional args): +claude mcp add --transport http notion http://localhost:8080/mcp/notion + +# If your Claude Code version supports authorization headers: +# export JWT= +# claude mcp add --transport http notion http://localhost:8080/mcp/notion --header "Authorization: Bearer $JWT" ``` ### Use the Console UI diff --git a/attach/cli_claude.py b/attach/cli_claude.py index e6910ce..1fd3a22 100644 --- a/attach/cli_claude.py +++ b/attach/cli_claude.py @@ -25,11 +25,10 @@ def claude_group(): @click.option( "--project", default=".", help="Project directory (default: current directory)" ) -@click.option("--bearer", help="Bearer token for Authorization header (optional)") @click.option( "--write-file", is_flag=True, help="Write .mcp.json file directly (experimental)" ) -def install_claude(project: str, bearer: str, write_file: bool): +def install_claude(project: str, write_file: bool): """ Generate Claude Code MCP configuration. @@ -37,6 +36,8 @@ def install_claude(project: str, bearer: str, write_file: bool): Default mode: Prints 'claude mcp add' commands to run. --write-file mode: Writes /.mcp.json directly (experimental, schema may be outdated). + + Note: Bearer tokens are never printed. Use $JWT environment variable. 
""" enabled = get_enabled_servers() @@ -51,44 +52,36 @@ def install_claude(project: str, bearer: str, write_file: bool): gateway_url = os.getenv("ATTACH_GATEWAY_URL", "http://localhost:8080") if write_file: - _write_mcp_json_file(project, enabled, gateway_url, bearer) + _write_mcp_json_file(project, enabled, gateway_url) else: - _print_claude_commands(enabled, gateway_url, bearer) + _print_claude_commands(enabled, gateway_url) -def _print_claude_commands(servers: dict, gateway_url: str, bearer: str): - """Print 'claude mcp add' commands for each server.""" - click.echo("Run the following commands to add MCP servers to Claude Code:\n") +def _print_claude_commands(servers: dict, gateway_url: str): + """Print 'claude mcp add' commands for each server (positional args form).""" + click.echo("# Set your JWT token in the environment first:") + click.echo("# export JWT=") + click.echo() + click.echo("# Run these commands to add MCP servers to Claude Code:\n") for server_name in servers.keys(): server_url = f"{gateway_url}/mcp/{server_name}" - if bearer: - # Check if Claude Code supports headers in HTTP transport - click.echo(f"# Note: If Claude Code HTTP transport supports headers:") - click.echo( - f'claude mcp add --transport http --name "{server_name}" --url "{server_url}" --header "Authorization: Bearer {bearer}"' - ) - click.echo() - click.echo( - f"# If headers are not supported, you'll need to configure auth separately or use a different approach." - ) - else: - click.echo( - f'claude mcp add --transport http --name "{server_name}" --url "{server_url}"' - ) - + # Use positional args (most common CLI pattern) + click.echo(f"claude mcp add --transport http {server_name} {server_url}") click.echo() - if not bearer: - click.echo( - "Tip: Pass --bearer to include Authorization header in commands (if supported by Claude)." 
- ) + click.echo("# If Claude Code HTTP transport supports authorization headers:") + click.echo( + '# claude mcp add --transport http --header "Authorization: Bearer $JWT"' + ) + click.echo() + click.echo( + "# Note: JWT is passed via environment variable to avoid exposing tokens in shell history." + ) -def _write_mcp_json_file( - project_dir: str, servers: dict, gateway_url: str, bearer: str -): +def _write_mcp_json_file(project_dir: str, servers: dict, gateway_url: str): """ Write .mcp.json file for Claude Code (experimental). @@ -114,14 +107,6 @@ def _write_mcp_json_file( "url": server_url, } - if bearer: - # This may not be the correct schema - Claude Code's HTTP transport may not support headers - server_config["headers"] = {"Authorization": f"Bearer {bearer}"} - click.echo( - f"Warning: Added Authorization header, but Claude Code HTTP transport may not support headers.", - err=True, - ) - mcp_config["mcpServers"][server_name] = server_config try: @@ -136,6 +121,10 @@ def _write_mcp_json_file( click.echo( "Prefer using 'claude mcp add' commands for guaranteed compatibility." ) + click.echo() + click.echo( + "Authorization headers are not included. Configure auth separately if needed." 
+ ) except IOError as exc: click.echo(f"Error writing {mcp_json_path}: {exc}", err=True) diff --git a/audit/sqlite.py b/audit/sqlite.py index 8e1ae8d..6e749cc 100644 --- a/audit/sqlite.py +++ b/audit/sqlite.py @@ -47,10 +47,13 @@ def get_db_path() -> Path: def init_db() -> None: """Initialize audit database schema.""" db_path = get_db_path() - conn = sqlite3.connect(db_path) + conn = sqlite3.connect(db_path, timeout=5.0) try: cursor = conn.cursor() + # Enable WAL mode for better concurrency + cursor.execute("PRAGMA journal_mode=WAL;") + # MCP events table cursor.execute( """ @@ -113,7 +116,7 @@ def insert_mcp_event( """Insert an MCP event into audit log.""" db_path = get_db_path() try: - conn = sqlite3.connect(db_path) + conn = sqlite3.connect(db_path, timeout=5.0) try: cursor = conn.cursor() cursor.execute( @@ -150,7 +153,7 @@ def query_mcp_events( return [] try: - conn = sqlite3.connect(db_path) + conn = sqlite3.connect(db_path, timeout=5.0) conn.row_factory = sqlite3.Row try: cursor = conn.cursor() @@ -201,7 +204,7 @@ def overview_stats() -> dict[str, Any]: } try: - conn = sqlite3.connect(db_path) + conn = sqlite3.connect(db_path, timeout=5.0) conn.row_factory = sqlite3.Row try: cursor = conn.cursor() @@ -276,7 +279,7 @@ def get_quota_count(user: str, tool: str, date_utc: str) -> int: return 0 try: - conn = sqlite3.connect(db_path) + conn = sqlite3.connect(db_path, timeout=5.0) try: cursor = conn.cursor() cursor.execute( @@ -296,7 +299,7 @@ def increment_quota_count(user: str, tool: str, date_utc: str) -> None: """Increment quota counter for a user/tool/day.""" db_path = get_db_path() try: - conn = sqlite3.connect(db_path) + conn = sqlite3.connect(db_path, timeout=5.0) try: cursor = conn.cursor() cursor.execute( @@ -313,3 +316,41 @@ def increment_quota_count(user: str, tool: str, date_utc: str) -> None: conn.close() except sqlite3.Error as exc: log.error("Failed to increment quota count: %s", exc) + + +def atomic_increment_and_get_quota_count(user: str, tool: 
str, date_utc: str) -> int: + """ + Atomically increment quota counter and return the NEW count. + + This prevents TOCTOU race conditions by incrementing first, then returning + the new count so the caller can check if the limit was exceeded. + + Returns: + The count AFTER increment (1-based). Returns 0 on error. + """ + db_path = get_db_path() + try: + conn = sqlite3.connect(db_path, timeout=5.0) + try: + cursor = conn.cursor() + # Use RETURNING clause to atomically increment and get new value + # SQLite 3.35+ supports RETURNING + cursor.execute( + """ + INSERT INTO mcp_counters (date_utc, user, tool, count) + VALUES (?, ?, ?, 1) + ON CONFLICT(date_utc, user, tool) + DO UPDATE SET count = count + 1 + RETURNING count + """, + (date_utc, user, tool), + ) + row = cursor.fetchone() + conn.commit() + return row[0] if row else 1 + finally: + conn.close() + except sqlite3.Error as exc: + log.error("Failed to atomic increment quota count: %s", exc) + # On error, return 0 (allow the request - fail open) + return 0 diff --git a/mcp/config.py b/mcp/config.py index 614c591..e9d6ee7 100644 --- a/mcp/config.py +++ b/mcp/config.py @@ -30,16 +30,21 @@ log = logging.getLogger(__name__) +def _attach_dir_path() -> Path: + """Return ~/.attach directory path (does NOT create it).""" + return Path.home() / ".attach" + + def get_attach_dir() -> Path: """Return ~/.attach directory, creating if needed.""" - attach_dir = Path.home() / ".attach" + attach_dir = _attach_dir_path() attach_dir.mkdir(exist_ok=True) return attach_dir def get_mcp_config_path() -> Path: - """Return path to MCP config file.""" - return get_attach_dir() / "mcp.json" + """Return path to MCP config file (does NOT create ~/.attach).""" + return _attach_dir_path() / "mcp.json" def load_mcp_config() -> dict[str, Any]: @@ -117,7 +122,10 @@ def is_mcp_enabled() -> bool: """ Check if MCP gateway is enabled. Returns True if ATTACH_ENABLE_MCP=true OR if ~/.attach/mcp.json exists. 
+ + Note: This function does NOT create ~/.attach directory as a side effect. """ if os.getenv("ATTACH_ENABLE_MCP", "").lower() == "true": return True + # Use get_mcp_config_path() which doesn't create the directory return get_mcp_config_path().exists() diff --git a/mcp/quota.py b/mcp/quota.py index 54aed0c..7b6a027 100644 --- a/mcp/quota.py +++ b/mcp/quota.py @@ -30,7 +30,7 @@ from pathlib import Path from typing import Any, Optional -from audit.sqlite import get_quota_count, increment_quota_count +from audit.sqlite import atomic_increment_and_get_quota_count, increment_quota_count from mcp.config import get_attach_dir log = logging.getLogger(__name__) @@ -69,8 +69,8 @@ def get_tool_limit(tool: str) -> Optional[int]: Matching priority: 1. Exact match - 2. First glob pattern that matches - 3. Wildcard "*" if present + 2. First glob pattern that matches (excluding bare "*") + 3. Wildcard "*" if present (catch-all fallback) """ policy = load_policy() limits = policy.get("per_user_daily_tool_calls", {}) @@ -79,13 +79,16 @@ def get_tool_limit(tool: str) -> Optional[int]: if tool in limits: return limits[tool] - # Try glob patterns + # Try glob patterns (skip bare "*" - it's handled as fallback) for pattern, limit in limits.items(): + # Skip the bare "*" pattern - we apply it only as final fallback + if pattern == "*": + continue if "*" in pattern or "?" in pattern or "[" in pattern: if fnmatch.fnmatch(tool, pattern): return limit - # Fallback to wildcard + # Fallback to wildcard catch-all (only if no specific pattern matched) if "*" in limits: return limits["*"] @@ -102,6 +105,9 @@ def check_quota(user: str, tool: str) -> tuple[bool, Optional[str]]: """ Check if user is within quota for tool. + NOTE: This only checks the quota without incrementing. Use check_and_reserve_quota() + for atomic check-and-increment to prevent TOCTOU race conditions. 
+ Returns: (allowed: bool, error_msg: Optional[str]) - (True, None) if allowed @@ -116,19 +122,30 @@ def check_quota(user: str, tool: str) -> tuple[bool, Optional[str]]: return (True, None) date_utc = get_current_date_utc() - current_count = get_quota_count(user, tool, date_utc) + # Use atomic increment to get current count and reserve our slot + new_count = atomic_increment_and_get_quota_count(user, tool, date_utc) - if current_count >= limit: - error_msg = f"tool quota exceeded: {tool} limit={limit} used={current_count}" + # new_count is the count AFTER increment (1-based) + # So if limit=1, we allow new_count=1 but deny new_count=2 + if new_count > limit: + error_msg = f"tool quota exceeded: {tool} limit={limit} used={new_count}" return (False, error_msg) return (True, None) def record_tool_call(user: str, tool: str) -> None: - """Record a tool call in quota counters.""" + """ + Record a tool call in quota counters. + + NOTE: This is now a no-op when quota is enabled because check_quota() + atomically increments the counter. Kept for backwards compatibility + and for cases where quota is disabled but tracking is still desired. 
+ """ if not is_quota_enabled(): + # Quota disabled: nothing to record, so return without touching the counters return - date_utc = get_current_date_utc() - increment_quota_count(user, tool, date_utc) + # When quota is enabled, the counter was already incremented in check_quota() + # No need to increment again + pass diff --git a/middleware/auth.py b/middleware/auth.py index 4e41e2b..9e76f2f 100644 --- a/middleware/auth.py +++ b/middleware/auth.py @@ -50,7 +50,8 @@ async def jwt_auth_mw(request: Request, call_next): return await call_next(request) # Skip authentication for console (unauthenticated landing page) - if request.url.path == "/console": + # Note: Check both with and without trailing slash to handle URL variations + if request.url.path in ("/console", "/console/"): return await call_next(request) # Skip authentication for excluded path prefixes (console static assets) diff --git a/middleware/session.py b/middleware/session.py index b70a73c..2e569c6 100644 --- a/middleware/session.py +++ b/middleware/session.py @@ -34,7 +34,8 @@ async def session_mw(request: Request, call_next): return await call_next(request) # Skip session middleware for console - if request.url.path == "/console": + # Note: Check both with and without trailing slash to handle URL variations + if request.url.path in ("/console", "/console/"): return await call_next(request) # Skip session middleware for excluded path prefixes diff --git a/tests/test_mcp_optin.py b/tests/test_mcp_optin.py index 9e7bf48..f7f6d8f 100644 --- a/tests/test_mcp_optin.py +++ b/tests/test_mcp_optin.py @@ -55,7 +55,11 @@ def temp_attach_dir(monkeypatch, tmp_path): import mcp.config - monkeypatch.setattr(mcp.config, "get_attach_dir", lambda: attach_dir) + # Patch _attach_dir_path so both get_attach_dir and get_mcp_config_path use temp dir + monkeypatch.setattr(mcp.config, "_attach_dir_path", lambda: attach_dir) + + # Disable Weaviate to avoid connection errors during tests + monkeypatch.setenv("MEM_BACKEND", "none") 
return attach_dir diff --git a/tests/test_mcp_proxy_quota.py b/tests/test_mcp_proxy_quota.py index c192600..a7a01cb 100644 --- a/tests/test_mcp_proxy_quota.py +++ b/tests/test_mcp_proxy_quota.py @@ -57,6 +57,9 @@ def temp_attach_dir(monkeypatch, tmp_path): import mcp.config import mcp.quota + # Patch _attach_dir_path for is_mcp_enabled() and get_mcp_config_path() + monkeypatch.setattr(mcp.config, "_attach_dir_path", lambda: attach_dir) + # Also patch get_attach_dir for backward compatibility monkeypatch.setattr(mcp.config, "get_attach_dir", lambda: attach_dir) monkeypatch.setattr(mcp.quota, "get_attach_dir", lambda: attach_dir) monkeypatch.setattr(audit.sqlite, "get_attach_dir", lambda: attach_dir) From 44d0906fa06ca010d12d78e6582bd9590c64daa5 Mon Sep 17 00:00:00 2001 From: Hammad Tariq Date: Wed, 7 Jan 2026 18:09:03 +0500 Subject: [PATCH 3/5] renamed namespaces of new packages --- .gitignore | 1 - CLAUDE.md | 148 ++++++++++++++++++ {audit => attach/audit}/__init__.py | 0 {audit => attach/audit}/sqlite.py | 2 +- attach/cli_claude.py | 2 +- attach/cli_mcp.py | 2 +- {console => attach/console}/__init__.py | 0 {console => attach/console}/router.py | 4 +- {console => attach/console}/static/app.js | 0 {console => attach/console}/static/index.html | 0 {console => attach/console}/static/style.css | 0 attach/gateway.py | 8 +- {mcp => attach/mcp}/__init__.py | 0 {mcp => attach/mcp}/config.py | 0 {mcp => attach/mcp}/proxy.py | 6 +- {mcp => attach/mcp}/quota.py | 4 +- {mcp => attach/mcp}/router.py | 4 +- pyproject.toml | 2 +- tests/test_console_auth.py | 10 +- tests/test_mcp_optin.py | 4 +- tests/test_mcp_proxy_quota.py | 38 +++-- 21 files changed, 190 insertions(+), 45 deletions(-) create mode 100644 CLAUDE.md rename {audit => attach/audit}/__init__.py (100%) rename {audit => attach/audit}/sqlite.py (99%) rename {console => attach/console}/__init__.py (100%) rename {console => attach/console}/router.py (96%) rename {console => attach/console}/static/app.js (100%) rename 
{console => attach/console}/static/index.html (100%) rename {console => attach/console}/static/style.css (100%) rename {mcp => attach/mcp}/__init__.py (100%) rename {mcp => attach/mcp}/config.py (100%) rename {mcp => attach/mcp}/proxy.py (97%) rename {mcp => attach/mcp}/quota.py (96%) rename {mcp => attach/mcp}/router.py (95%) diff --git a/.gitignore b/.gitignore index e4a619d..2438d9a 100644 --- a/.gitignore +++ b/.gitignore @@ -23,5 +23,4 @@ venv/ .venv/ __pyo3*.so *.egg-info/ -CLAUDE.md .claude/ \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..16a8e00 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,148 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Attach Gateway is a Python-based **OIDC/DID identity sidecar** for LLM engines (Ollama, vLLM) and multi-agent frameworks. It provides OIDC/DID-JWT authentication as the core, with optional A2A handoff, pluggable memory backends, and usage/quota features that can be enabled when needed. + +## Design / Code Philosophy + +**Attach is an identity sidecar first.** Many users install and run Attach only to enforce OIDC/DID auth and stamp identity/session headers in front of an LLM engine. That "OIDC sidecar" path must remain: +- fast to start +- low overhead +- stable and backwards compatible + +**Opt-in, not mandatory.** Any feature beyond OIDC/DID auth (memory backends, A2A routing, quotas, metering, MCP gateway, etc.) must be: +- gated behind explicit flags/config (env vars, config files, or CLI subcommands) +- disabled by default +- "lazy-loaded" (avoid importing heavy modules or starting background tasks unless the feature is enabled) + +**No surprise dependencies.** +- Keep the base install lean. +- Add heavyweight or optional integrations as extras (e.g. `.[quota]`, `.[usage]`, `.[full]`) and ensure the default path does not require them. 
+- Missing optional env vars must not crash the gateway; prefer graceful fallbacks with clear warnings. + +**Local-first and privacy-respecting by default.** +- No phone-home behavior unless explicitly enabled. +- Default logs/metrics should be local-only. If remote metering is supported, it must be opt-in and non-fatal. + +**Safe-by-default changes.** +- New routes/middlewares must not weaken authentication requirements for existing endpoints. +- Avoid breaking changes to required environment variables or the default startup flow. + +## Build and Development Commands + +```bash +# Install from source with all dev dependencies +pip install -e ".[dev,full]" + +# Run the gateway (development) +uvicorn main:app --port 8080 --reload + +# Run the gateway (CLI) +attach-gateway --port 8080 + +# Format code +black . +isort . + +# Run all tests +pytest tests/ + +# Run specific test file +pytest tests/test_jwt_middleware.py -v + +# Run tests with coverage +pytest tests/ --cov=. +``` + +## Required Environment Variables + +```bash +OIDC_ISSUER=https://your-domain.auth0.com/ # OIDC provider issuer URL +OIDC_AUD=your-api-identifier # Expected JWT audience claim +``` + +Optional variables: +- `ENGINE_URL`: LLM engine endpoint (default: `http://localhost:11434`) +- `MEM_BACKEND`: Memory backend - `none` (default), `weaviate`, or `sakana` +- `WEAVIATE_URL`: Required if `MEM_BACKEND=weaviate` +- `MAX_TOKENS_PER_MIN`: Enables token quota middleware +- `USAGE_METERING`: `null` (default), `prometheus`, or `openmeter` + - Note: metering backends must remain optional; missing keys/config should gracefully fall back to `null` behavior. 
+ +## Architecture + +### Request Flow +``` +Client Request → middleware/auth.py (JWT validation) + → middleware/session.py (session ID generation) + → proxy/engine.py or a2a/routes.py + → Memory backend (fire-and-forget write) +``` + +### Key Modules +- **auth/**: OIDC JWT and DID token verification (`oidc.py`, `did.py`) +- **middleware/**: Stateless header processing - auth extraction, session stamping, quota enforcement +- **proxy/**: Engine-agnostic HTTP streaming proxy to Ollama/vLLM +- **a2a/**: Agent-to-agent task routing (`/a2a/tasks/send`, `/a2a/tasks/status`) +- **mem/**: Pluggable memory backends with factory pattern +- **usage/**: Token metering backends (Prometheus, OpenMeter) + +### Authentication Dispatch +`auth/__init__.py` routes tokens by format: +- 2 dots → OIDC JWT (`auth/oidc.py`) - RS256/ES256 only +- 3+ dots → DID token (`auth/did.py`) - did:key or did:pkh + +## Code Conventions + +- Python 3.10+ with type hints everywhere using `from __future__ import annotations` +- All FastAPI routes must be `async`; wrap blocking I/O in `loop.run_in_executor` +- Use `aiter_bytes()` for streaming responses (constant memory) +- Module size limit: 400 lines; extract helpers if larger +- Format with `black`, sort imports with `isort` +- Commit messages: Conventional Commits (`feat:`, `fix:`, `docs:`) + +## Security Requirements + +- Reject HS256 JWTs; only accept RS256/ES256 +- Enforce `aud` and `exp` claims (60s clock skew allowed) +- Session IDs: `sha256(user.sub + user-agent)` - non-guessable +- Log only first 8 chars of JWT `sub` claim; never log full tokens + +## Testing + +- Use `pytest-asyncio` for async tests +- Mock network calls with `httpx.MockTransport` +- Test files in `tests/` directory +- Config: `pytest.ini` sets `pythonpath = .` + +## Starting Local Services + +```bash +# Start Weaviate memory backend +docker run --rm -d -p 6666:8080 \ + -e AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED=true \ + semitechnologies/weaviate:1.30.5 + +# Or use the 
helper script +python script/start_weaviate.py +``` + +## Multi-Agent Demo + +```bash +# Terminal 1: Gateway +uvicorn main:app --port 8080 + +# Terminal 2: Planner agent +uvicorn examples.agents.planner:app --port 8100 + +# Terminal 3: Coder agent +uvicorn examples.agents.coder:app --port 8101 + +# Terminal 4: Demo UI +cd examples/static && python -m http.server 9000 +# Open http://localhost:9000/demo.html +``` diff --git a/audit/__init__.py b/attach/audit/__init__.py similarity index 100% rename from audit/__init__.py rename to attach/audit/__init__.py diff --git a/audit/sqlite.py b/attach/audit/sqlite.py similarity index 99% rename from audit/sqlite.py rename to attach/audit/sqlite.py index 6e749cc..1fe53f9 100644 --- a/audit/sqlite.py +++ b/attach/audit/sqlite.py @@ -34,7 +34,7 @@ from pathlib import Path from typing import Any, Optional -from mcp.config import get_attach_dir +from attach.mcp.config import get_attach_dir log = logging.getLogger(__name__) diff --git a/attach/cli_claude.py b/attach/cli_claude.py index 1fd3a22..44d7a25 100644 --- a/attach/cli_claude.py +++ b/attach/cli_claude.py @@ -12,7 +12,7 @@ import click -from mcp.config import get_enabled_servers +from attach.mcp.config import get_enabled_servers @click.group(name="claude") diff --git a/attach/cli_mcp.py b/attach/cli_mcp.py index 7d2ae35..d24c9ac 100644 --- a/attach/cli_mcp.py +++ b/attach/cli_mcp.py @@ -12,7 +12,7 @@ import click -from mcp.config import ( +from attach.mcp.config import ( get_mcp_config_path, load_mcp_config, save_mcp_config, diff --git a/console/__init__.py b/attach/console/__init__.py similarity index 100% rename from console/__init__.py rename to attach/console/__init__.py diff --git a/console/router.py b/attach/console/router.py similarity index 96% rename from console/router.py rename to attach/console/router.py index b81b062..8270c1b 100644 --- a/console/router.py +++ b/attach/console/router.py @@ -22,8 +22,8 @@ from fastapi import APIRouter, Request from 
fastapi.responses import FileResponse, JSONResponse -from audit.sqlite import overview_stats, query_mcp_events -from mcp.config import get_enabled_servers +from attach.audit.sqlite import overview_stats, query_mcp_events +from attach.mcp.config import get_enabled_servers log = logging.getLogger(__name__) diff --git a/console/static/app.js b/attach/console/static/app.js similarity index 100% rename from console/static/app.js rename to attach/console/static/app.js diff --git a/console/static/index.html b/attach/console/static/index.html similarity index 100% rename from console/static/index.html rename to attach/console/static/index.html diff --git a/console/static/style.css b/attach/console/static/style.css similarity index 100% rename from console/static/style.css rename to attach/console/static/style.css diff --git a/attach/gateway.py b/attach/gateway.py index 7cee8a1..47445a0 100644 --- a/attach/gateway.py +++ b/attach/gateway.py @@ -118,7 +118,7 @@ async def lifespan(app: FastAPI): # Initialize MCP audit DB if MCP is enabled if getattr(app.state, "mcp_enabled", False): - from audit.sqlite import init_db + from attach.audit.sqlite import init_db init_db() @@ -194,14 +194,14 @@ async def auth_config(): app.include_router(mem_router) # Conditionally mount MCP and console routers (opt-in) - from mcp.config import is_mcp_enabled + from attach.mcp.config import is_mcp_enabled mcp_enabled = is_mcp_enabled() app.state.mcp_enabled = mcp_enabled if mcp_enabled: - from console.router import router as console_router - from mcp.router import router as mcp_router + from attach.console.router import router as console_router + from attach.mcp.router import router as mcp_router app.include_router(mcp_router) app.include_router(console_router) diff --git a/mcp/__init__.py b/attach/mcp/__init__.py similarity index 100% rename from mcp/__init__.py rename to attach/mcp/__init__.py diff --git a/mcp/config.py b/attach/mcp/config.py similarity index 100% rename from mcp/config.py 
rename to attach/mcp/config.py diff --git a/mcp/proxy.py b/attach/mcp/proxy.py similarity index 97% rename from mcp/proxy.py rename to attach/mcp/proxy.py index 8c47c74..6580b4c 100644 --- a/mcp/proxy.py +++ b/attach/mcp/proxy.py @@ -13,9 +13,9 @@ import httpx -from audit.sqlite import insert_mcp_event -from mcp.config import get_enabled_servers, get_server_headers -from mcp.quota import check_quota, record_tool_call +from attach.audit.sqlite import insert_mcp_event +from attach.mcp.config import get_enabled_servers, get_server_headers +from attach.mcp.quota import check_quota, record_tool_call log = logging.getLogger(__name__) diff --git a/mcp/quota.py b/attach/mcp/quota.py similarity index 96% rename from mcp/quota.py rename to attach/mcp/quota.py index 7b6a027..1088e7c 100644 --- a/mcp/quota.py +++ b/attach/mcp/quota.py @@ -30,8 +30,8 @@ from pathlib import Path from typing import Any, Optional -from audit.sqlite import atomic_increment_and_get_quota_count, increment_quota_count -from mcp.config import get_attach_dir +from attach.audit.sqlite import atomic_increment_and_get_quota_count, increment_quota_count +from attach.mcp.config import get_attach_dir log = logging.getLogger(__name__) diff --git a/mcp/router.py b/attach/mcp/router.py similarity index 95% rename from mcp/router.py rename to attach/mcp/router.py index e25cedb..e0cfd65 100644 --- a/mcp/router.py +++ b/attach/mcp/router.py @@ -14,8 +14,8 @@ from fastapi import APIRouter, Request from fastapi.responses import JSONResponse -from mcp.config import get_enabled_servers -from mcp.proxy import proxy_mcp_request +from attach.mcp.config import get_enabled_servers +from attach.mcp.proxy import proxy_mcp_request log = logging.getLogger(__name__) diff --git a/pyproject.toml b/pyproject.toml index feb6a95..d27ef50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ attach-gateway = "attach.__main__:main" # Include your existing modules with cleaner targeting [tool.hatch.build.targets.wheel] 
-packages = ["attach", "auth", "middleware", "mem", "proxy", "a2a", "attach_pydid", "usage", "utils", "logs", "mcp", "audit", "console"] +packages = ["attach", "auth", "middleware", "mem", "proxy", "a2a", "attach_pydid", "usage", "utils", "logs"] # Dynamic version from attach/__init__.py diff --git a/tests/test_console_auth.py b/tests/test_console_auth.py index 0694eab..f7349ac 100644 --- a/tests/test_console_auth.py +++ b/tests/test_console_auth.py @@ -53,18 +53,18 @@ def temp_attach_dir(monkeypatch, tmp_path): attach_dir = tmp_path / "attach" attach_dir.mkdir() - import audit.sqlite - import mcp.config + import attach.audit.sqlite + import attach.mcp.config - monkeypatch.setattr(mcp.config, "get_attach_dir", lambda: attach_dir) - monkeypatch.setattr(audit.sqlite, "get_attach_dir", lambda: attach_dir) + monkeypatch.setattr(attach.mcp.config, "get_attach_dir", lambda: attach_dir) + monkeypatch.setattr(attach.audit.sqlite, "get_attach_dir", lambda: attach_dir) # Create mcp.json to enable MCP mcp_config = {"version": 1, "servers": {}} (attach_dir / "mcp.json").write_text(json.dumps(mcp_config)) # Initialize audit DB - from audit.sqlite import init_db + from attach.audit.sqlite import init_db init_db() diff --git a/tests/test_mcp_optin.py b/tests/test_mcp_optin.py index f7f6d8f..eea9a2b 100644 --- a/tests/test_mcp_optin.py +++ b/tests/test_mcp_optin.py @@ -53,10 +53,10 @@ def temp_attach_dir(monkeypatch, tmp_path): attach_dir = tmp_path / "attach" attach_dir.mkdir() - import mcp.config + import attach.mcp.config # Patch _attach_dir_path so both get_attach_dir and get_mcp_config_path use temp dir - monkeypatch.setattr(mcp.config, "_attach_dir_path", lambda: attach_dir) + monkeypatch.setattr(attach.mcp.config, "_attach_dir_path", lambda: attach_dir) # Disable Weaviate to avoid connection errors during tests monkeypatch.setenv("MEM_BACKEND", "none") diff --git a/tests/test_mcp_proxy_quota.py b/tests/test_mcp_proxy_quota.py index a7a01cb..a8f7383 100644 --- 
a/tests/test_mcp_proxy_quota.py +++ b/tests/test_mcp_proxy_quota.py @@ -53,16 +53,16 @@ def temp_attach_dir(monkeypatch, tmp_path): attach_dir = tmp_path / "attach" attach_dir.mkdir() - import audit.sqlite - import mcp.config - import mcp.quota + import attach.audit.sqlite + import attach.mcp.config + import attach.mcp.quota # Patch _attach_dir_path for is_mcp_enabled() and get_mcp_config_path() - monkeypatch.setattr(mcp.config, "_attach_dir_path", lambda: attach_dir) + monkeypatch.setattr(attach.mcp.config, "_attach_dir_path", lambda: attach_dir) # Also patch get_attach_dir for backward compatibility - monkeypatch.setattr(mcp.config, "get_attach_dir", lambda: attach_dir) - monkeypatch.setattr(mcp.quota, "get_attach_dir", lambda: attach_dir) - monkeypatch.setattr(audit.sqlite, "get_attach_dir", lambda: attach_dir) + monkeypatch.setattr(attach.mcp.config, "get_attach_dir", lambda: attach_dir) + monkeypatch.setattr(attach.mcp.quota, "get_attach_dir", lambda: attach_dir) + monkeypatch.setattr(attach.audit.sqlite, "get_attach_dir", lambda: attach_dir) return attach_dir @@ -98,7 +98,7 @@ async def test_mcp_proxy_forwards_request(temp_attach_dir, fake_mcp_server): (temp_attach_dir / "mcp.json").write_text(json.dumps(mcp_config)) # Initialize audit DB - from audit.sqlite import init_db + from attach.audit.sqlite import init_db init_db() @@ -108,9 +108,9 @@ async def test_mcp_proxy_forwards_request(temp_attach_dir, fake_mcp_server): # Mock the httpx client to return fake upstream response from unittest.mock import Mock - import mcp.proxy + import attach.mcp.proxy - original_client = mcp.proxy.httpx.AsyncClient + original_client = attach.mcp.proxy.httpx.AsyncClient class MockAsyncClient: def __init__(self, *args, **kwargs): @@ -134,9 +134,7 @@ async def post(self, url, json, headers): } return mock_response - import mcp.proxy - - mcp.proxy.httpx.AsyncClient = MockAsyncClient + attach.mcp.proxy.httpx.AsyncClient = MockAsyncClient try: async with AsyncClient( @@ -153,7 +151,7 
@@ async def post(self, url, json, headers): data = response.json() assert data["result"]["status"] == "ok" finally: - mcp.proxy.httpx.AsyncClient = original_client + attach.mcp.proxy.httpx.AsyncClient = original_client @pytest.mark.asyncio @@ -177,7 +175,7 @@ async def test_mcp_quota_enforcement(temp_attach_dir): (temp_attach_dir / "mcp_policy.json").write_text(json.dumps(policy_config)) # Initialize audit DB - from audit.sqlite import init_db + from attach.audit.sqlite import init_db init_db() @@ -187,7 +185,7 @@ async def test_mcp_quota_enforcement(temp_attach_dir): # Mock httpx client from unittest.mock import Mock - import mcp.proxy + import attach.mcp.proxy class MockAsyncClient: def __init__(self, *args, **kwargs): @@ -211,8 +209,8 @@ async def post(self, url, json, headers): } return mock_response - original_client = mcp.proxy.httpx.AsyncClient - mcp.proxy.httpx.AsyncClient = MockAsyncClient + original_client = attach.mcp.proxy.httpx.AsyncClient + attach.mcp.proxy.httpx.AsyncClient = MockAsyncClient try: async with AsyncClient( @@ -251,7 +249,7 @@ async def post(self, url, json, headers): assert "error" in data2 assert data2["error"]["code"] == -32029 # quota exceeded code finally: - mcp.proxy.httpx.AsyncClient = original_client + attach.mcp.proxy.httpx.AsyncClient = original_client @pytest.mark.asyncio @@ -262,7 +260,7 @@ async def test_mcp_server_not_found(temp_attach_dir): (temp_attach_dir / "mcp.json").write_text(json.dumps(mcp_config)) # Initialize audit DB - from audit.sqlite import init_db + from attach.audit.sqlite import init_db init_db() From 51f816bd484fa479a526168a7f7c01a240ebbab3 Mon Sep 17 00:00:00 2001 From: Hammad Tariq Date: Sun, 18 Jan 2026 01:40:52 +0500 Subject: [PATCH 4/5] test: add integration tests for MCP gateway, auth, and quota MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Comprehensive integration tests covering: - MCP endpoint authentication requirements - Multi-tenant user isolation in 
audit logs - Per-user quota enforcement and isolation - Console public landing vs protected API auth model - MCP opt-in/opt-out behavior - End-to-end Ollama proxy with JWT auth - Gateway health checks and CORS 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- tests/test_integration_e2e.py | 209 ++++++++++++ tests/test_integration_mcp_auth.py | 493 +++++++++++++++++++++++++++++ 2 files changed, 702 insertions(+) create mode 100644 tests/test_integration_e2e.py create mode 100644 tests/test_integration_mcp_auth.py diff --git a/tests/test_integration_e2e.py b/tests/test_integration_e2e.py new file mode 100644 index 0000000..123db38 --- /dev/null +++ b/tests/test_integration_e2e.py @@ -0,0 +1,209 @@ +""" +End-to-end integration tests with real Ollama backend. + +These tests require: +1. Ollama running on localhost:11434 +2. A model available (e.g., tinyllama) + +Skip these tests if Ollama is not available. +""" + +import os + +import httpx +import pytest +from httpx import ASGITransport, AsyncClient + +# Check if Ollama is available +def is_ollama_available(): + try: + resp = httpx.get("http://localhost:11434/api/tags", timeout=2) + return resp.status_code == 200 + except Exception: + return False + + +OLLAMA_AVAILABLE = is_ollama_available() +SKIP_REASON = "Ollama not running on localhost:11434" + +# Set required env vars +os.environ["OIDC_ISSUER"] = "https://test.auth0.com/" +os.environ["OIDC_AUD"] = "test-api" +os.environ["MEM_BACKEND"] = "none" +os.environ["ENGINE_URL"] = "http://localhost:11434" + +from jose import JWTError + +import auth.oidc +import middleware.auth +from attach.gateway import create_app + +VALID_TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJlMmUtdXNlciJ9.test" + + +@pytest.fixture(autouse=True) +def stub_verify_jwt(monkeypatch): + """Stub JWT verification.""" + + def fake_verify_sync(token: str, *, leeway: int = 60): + if token == VALID_TOKEN: + return {"sub": "e2e-user"} + raise 
JWTError("invalid token") + + async def fake_verify_async(token: str, *, leeway: int = 60): + return fake_verify_sync(token, leeway=leeway) + + monkeypatch.setattr(auth.oidc, "verify_jwt", fake_verify_sync) + monkeypatch.setattr(auth.oidc, "verify_jwt_with_exchange", fake_verify_async) + monkeypatch.setattr(middleware.auth, "verify_jwt", fake_verify_sync) + monkeypatch.setattr(middleware.auth, "verify_jwt_with_exchange", fake_verify_async) + + +@pytest.mark.skipif(not OLLAMA_AVAILABLE, reason=SKIP_REASON) +@pytest.mark.asyncio +async def test_e2e_ollama_proxy_chat(): + """Test end-to-end chat completion through gateway to Ollama. + + Note: Gateway proxies to /api/chat endpoint (Ollama native format). + """ + app = create_app() + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test", timeout=60.0 + ) as client: + response = await client.post( + "/api/chat", + json={ + "model": "tinyllama", + "messages": [{"role": "user", "content": "Say hi"}], + "stream": False, + }, + headers={"Authorization": f"Bearer {VALID_TOKEN}"}, + ) + + # Should succeed + assert response.status_code == 200 + data = response.json() + + # Response structure - gateway may transform to OpenAI format + assert "message" in data or "response" in data or "choices" in data + + +@pytest.mark.skipif(not OLLAMA_AVAILABLE, reason=SKIP_REASON) +@pytest.mark.asyncio +async def test_e2e_ollama_direct(): + """Test direct Ollama API access (not through gateway proxy).""" + # This tests that Ollama itself is working + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + "http://localhost:11434/api/generate", + json={ + "model": "tinyllama", + "prompt": "Hi", + "stream": False, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert "response" in data + + +@pytest.mark.skipif(not OLLAMA_AVAILABLE, reason=SKIP_REASON) +@pytest.mark.asyncio +async def test_e2e_ollama_tags(): + """Test listing Ollama models directly.""" + 
async with httpx.AsyncClient() as client: + response = await client.get("http://localhost:11434/api/tags") + + assert response.status_code == 200 + data = response.json() + assert "models" in data + + +@pytest.mark.skipif(not OLLAMA_AVAILABLE, reason=SKIP_REASON) +@pytest.mark.asyncio +async def test_e2e_requires_auth(): + """Test that proxy endpoints require authentication.""" + app = create_app() + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + # Without auth - /api/chat is the gateway's proxy endpoint + response = await client.post( + "/api/chat", + json={"model": "tinyllama", "messages": [{"role": "user", "content": "test"}], "stream": False}, + ) + assert response.status_code == 401 + + # With invalid token + response = await client.post( + "/api/chat", + json={"model": "tinyllama", "messages": [{"role": "user", "content": "test"}], "stream": False}, + headers={"Authorization": "Bearer invalid"}, + ) + assert response.status_code == 401 + + +@pytest.mark.skipif(not OLLAMA_AVAILABLE, reason=SKIP_REASON) +@pytest.mark.asyncio +async def test_e2e_session_header_injected(): + """Test that authenticated requests go through session middleware.""" + app = create_app() + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test", timeout=60.0 + ) as client: + response = await client.post( + "/api/chat", + json={ + "model": "tinyllama", + "messages": [{"role": "user", "content": "hi"}], + "stream": False, + }, + headers={"Authorization": f"Bearer {VALID_TOKEN}"}, + ) + + assert response.status_code == 200 + # Session ID should be set in request state + # If the request succeeded through middleware, session was processed + + +@pytest.mark.asyncio +async def test_gateway_health_check(): + """Test gateway starts and responds to basic requests.""" + app = create_app() + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + # Auth config endpoint 
should work without auth + response = await client.get("/auth/config") + assert response.status_code == 200 + data = response.json() + assert "audience" in data + + +@pytest.mark.asyncio +async def test_gateway_cors_headers(): + """Test that CORS headers are properly set.""" + app = create_app() + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + # OPTIONS request (preflight) + response = await client.options( + "/api/generate", + headers={ + "Origin": "http://localhost:9000", + "Access-Control-Request-Method": "POST", + }, + ) + + # Should allow the origin + assert response.headers.get("access-control-allow-origin") in [ + "http://localhost:9000", + "*", + ] diff --git a/tests/test_integration_mcp_auth.py b/tests/test_integration_mcp_auth.py new file mode 100644 index 0000000..c2654f8 --- /dev/null +++ b/tests/test_integration_mcp_auth.py @@ -0,0 +1,493 @@ +""" +Integration tests for MCP Gateway with auth, quota, and multi-tenant features. + +Tests verify: +1. MCP endpoints require JWT authentication +2. MCP quota enforcement across multiple users +3. Console API auth model (public landing, protected API) +4. Multi-tenant user isolation in audit logs +5. 
Token exchange flow (Descope/Auth0) +""" + +import json +import os +from pathlib import Path + +import pytest +from httpx import ASGITransport, AsyncClient + +# Set required env vars before importing gateway +os.environ["OIDC_ISSUER"] = "https://test.auth0.com/" +os.environ["OIDC_AUD"] = "test-api" +os.environ["ATTACH_ENABLE_MCP"] = "true" +os.environ["MEM_BACKEND"] = "none" + +from jose import JWTError + +import auth.oidc +import middleware.auth +from attach.gateway import create_app + +# Test tokens for different users +USER_ALICE_TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJhbGljZSJ9.test1" +USER_BOB_TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJib2IifQ.test2" +USER_CHARLIE_TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJjaGFybGllIn0.test3" +INVALID_TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJiYWQifQ.invalid" + + +@pytest.fixture(autouse=True) +def stub_verify_jwt(monkeypatch): + """Stub JWT verification to allow test tokens.""" + + def fake_verify_sync(token: str, *, leeway: int = 60): + token_map = { + USER_ALICE_TOKEN: {"sub": "alice"}, + USER_BOB_TOKEN: {"sub": "bob"}, + USER_CHARLIE_TOKEN: {"sub": "charlie"}, + } + if token in token_map: + return token_map[token] + raise JWTError("invalid token") + + async def fake_verify_async(token: str, *, leeway: int = 60): + return fake_verify_sync(token, leeway=leeway) + + monkeypatch.setattr(auth.oidc, "verify_jwt", fake_verify_sync) + monkeypatch.setattr(auth.oidc, "verify_jwt_with_exchange", fake_verify_async) + monkeypatch.setattr(middleware.auth, "verify_jwt", fake_verify_sync) + monkeypatch.setattr(middleware.auth, "verify_jwt_with_exchange", fake_verify_async) + + +@pytest.fixture +def temp_attach_dir(monkeypatch, tmp_path): + """Override ~/.attach directory with temp dir.""" + attach_dir = tmp_path / "attach" + attach_dir.mkdir() + + import attach.audit.sqlite + import attach.mcp.config + import attach.mcp.quota + + monkeypatch.setattr(attach.mcp.config, 
"_attach_dir_path", lambda: attach_dir) + monkeypatch.setattr(attach.mcp.config, "get_attach_dir", lambda: attach_dir) + monkeypatch.setattr(attach.mcp.quota, "get_attach_dir", lambda: attach_dir) + monkeypatch.setattr(attach.audit.sqlite, "get_attach_dir", lambda: attach_dir) + + return attach_dir + + +@pytest.fixture +def setup_mcp_servers(temp_attach_dir): + """Configure MCP servers for testing.""" + mcp_config = { + "version": 1, + "servers": { + "github": {"enabled": True, "url": "http://mock-github/mcp"}, + "notion": {"enabled": True, "url": "http://mock-notion/mcp"}, + "disabled-server": {"enabled": False, "url": "http://mock-disabled/mcp"}, + }, + } + (temp_attach_dir / "mcp.json").write_text(json.dumps(mcp_config)) + return mcp_config + + +@pytest.fixture +def setup_quota_policy(temp_attach_dir): + """Configure quota policy for testing.""" + policy = { + "version": 1, + "enabled": True, + "per_user_daily_tool_calls": { + "github.*": 5, + "notion.*": 3, + "*": 10, + }, + } + (temp_attach_dir / "mcp_policy.json").write_text(json.dumps(policy)) + return policy + + +@pytest.fixture +def init_audit_db(temp_attach_dir): + """Initialize audit database.""" + from attach.audit.sqlite import init_db + + init_db() + + +@pytest.fixture +def mock_mcp_upstream(monkeypatch): + """Mock MCP upstream servers.""" + from unittest.mock import Mock + + import attach.mcp.proxy + + class MockAsyncClient: + def __init__(self, *args, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + async def post(self, url, json, headers): + mock_response = Mock() + mock_response.status_code = 200 + mock_response.is_success = True + mock_response.json.return_value = { + "jsonrpc": "2.0", + "id": json.get("id"), + "result": {"status": "ok", "server": url}, + } + return mock_response + + original = attach.mcp.proxy.httpx.AsyncClient + attach.mcp.proxy.httpx.AsyncClient = MockAsyncClient + yield + attach.mcp.proxy.httpx.AsyncClient = 
original + + +# ============================================================================ +# Test: Authentication Requirements +# ============================================================================ + + +@pytest.mark.asyncio +async def test_mcp_list_requires_auth(temp_attach_dir, setup_mcp_servers, init_audit_db): + """Test that /mcp endpoint requires JWT authentication.""" + app = create_app() + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + # Without auth + response = await client.get("/mcp") + assert response.status_code == 401 + + # With invalid token + response = await client.get( + "/mcp", headers={"Authorization": f"Bearer {INVALID_TOKEN}"} + ) + assert response.status_code == 401 + + # With valid token + response = await client.get( + "/mcp", headers={"Authorization": f"Bearer {USER_ALICE_TOKEN}"} + ) + assert response.status_code == 200 + data = response.json() + assert "servers" in data + assert "github" in data["servers"] + assert "notion" in data["servers"] + # Disabled server should not appear + assert "disabled-server" not in data["servers"] + + +@pytest.mark.asyncio +async def test_mcp_proxy_requires_auth( + temp_attach_dir, setup_mcp_servers, init_audit_db, mock_mcp_upstream +): + """Test that /mcp/{server} proxy requires JWT authentication.""" + app = create_app() + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + payload = {"jsonrpc": "2.0", "method": "tools/list", "id": 1} + + # Without auth + response = await client.post("/mcp/github", json=payload) + assert response.status_code == 401 + + # With valid token + response = await client.post( + "/mcp/github", + json=payload, + headers={"Authorization": f"Bearer {USER_ALICE_TOKEN}"}, + ) + assert response.status_code == 200 + + +# ============================================================================ +# Test: Multi-Tenant User Isolation +# 
============================================================================ + + +@pytest.mark.asyncio +async def test_multi_tenant_audit_isolation( + temp_attach_dir, setup_mcp_servers, init_audit_db, mock_mcp_upstream +): + """Test that audit logs properly isolate different users.""" + app = create_app() + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + # Alice makes a request + await client.post( + "/mcp/github", + json={ + "jsonrpc": "2.0", + "method": "tools/call", + "params": {"name": "github.create_issue"}, + "id": 1, + }, + headers={"Authorization": f"Bearer {USER_ALICE_TOKEN}"}, + ) + + # Bob makes a request + await client.post( + "/mcp/notion", + json={ + "jsonrpc": "2.0", + "method": "tools/call", + "params": {"name": "notion.create_page"}, + "id": 2, + }, + headers={"Authorization": f"Bearer {USER_BOB_TOKEN}"}, + ) + + # Query audit logs + from attach.audit.sqlite import query_mcp_events + + all_events = query_mcp_events(limit=100) + assert len(all_events) >= 2 + + alice_events = query_mcp_events(limit=100, user="alice") + bob_events = query_mcp_events(limit=100, user="bob") + + # Verify isolation + assert len(alice_events) >= 1 + assert len(bob_events) >= 1 + assert all(e["user"] == "alice" for e in alice_events) + assert all(e["user"] == "bob" for e in bob_events) + + +# ============================================================================ +# Test: Per-User Quota Enforcement +# ============================================================================ + + +@pytest.mark.asyncio +async def test_per_user_quota_isolation( + temp_attach_dir, setup_mcp_servers, setup_quota_policy, init_audit_db, mock_mcp_upstream +): + """Test that quota is enforced per-user, not globally.""" + app = create_app() + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + # Alice exhausts her notion quota (limit=3) + for i in range(3): + response = await client.post( + 
"/mcp/notion", + json={ + "jsonrpc": "2.0", + "method": "tools/call", + "params": {"name": "notion.create_page"}, + "id": i + 1, + }, + headers={"Authorization": f"Bearer {USER_ALICE_TOKEN}"}, + ) + assert response.status_code == 200 + data = response.json() + assert "result" in data # Should succeed + + # Alice's 4th request should be denied + response = await client.post( + "/mcp/notion", + json={ + "jsonrpc": "2.0", + "method": "tools/call", + "params": {"name": "notion.create_page"}, + "id": 100, + }, + headers={"Authorization": f"Bearer {USER_ALICE_TOKEN}"}, + ) + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == -32029 # quota exceeded + + # But Bob should still be able to use notion (his own quota) + response = await client.post( + "/mcp/notion", + json={ + "jsonrpc": "2.0", + "method": "tools/call", + "params": {"name": "notion.create_page"}, + "id": 200, + }, + headers={"Authorization": f"Bearer {USER_BOB_TOKEN}"}, + ) + assert response.status_code == 200 + data = response.json() + assert "result" in data # Should succeed (Bob's own quota) + + +@pytest.mark.asyncio +async def test_quota_glob_patterns( + temp_attach_dir, setup_mcp_servers, setup_quota_policy, init_audit_db, mock_mcp_upstream +): + """Test that quota glob patterns work correctly. + + Note: Quota counts per exact tool name, not per pattern. + So calling github.create_issue 5 times = 5 calls toward github.* limit. 
+ """ + app = create_app() + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + # GitHub tools have limit=5, call same tool 5 times + for i in range(5): + response = await client.post( + "/mcp/github", + json={ + "jsonrpc": "2.0", + "method": "tools/call", + "params": {"name": "github.create_issue"}, + "id": i + 1, + }, + headers={"Authorization": f"Bearer {USER_CHARLIE_TOKEN}"}, + ) + assert response.status_code == 200 + data = response.json() + assert "result" in data + + # 6th call to same tool should be denied + response = await client.post( + "/mcp/github", + json={ + "jsonrpc": "2.0", + "method": "tools/call", + "params": {"name": "github.create_issue"}, + "id": 999, + }, + headers={"Authorization": f"Bearer {USER_CHARLIE_TOKEN}"}, + ) + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == -32029 + + +# ============================================================================ +# Test: Console Auth Model +# ============================================================================ + + +@pytest.mark.asyncio +async def test_console_public_vs_protected(temp_attach_dir, setup_mcp_servers, init_audit_db): + """Test console auth model: public landing, protected API.""" + app = create_app() + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + # Public endpoints (no auth required) + response = await client.get("/console") + assert response.status_code == 200 + + response = await client.get("/console/static/app.js") + # 200 or 404 (file may not exist), but NOT 401 + assert response.status_code in [200, 404] + + # Protected API endpoints + response = await client.get("/console/api/overview") + assert response.status_code == 401 + + response = await client.get("/console/api/events") + assert response.status_code == 401 + + response = await client.get("/console/api/servers") + assert 
response.status_code == 401 + + # With valid auth + response = await client.get( + "/console/api/overview", + headers={"Authorization": f"Bearer {USER_ALICE_TOKEN}"}, + ) + assert response.status_code == 200 + data = response.json() + assert "calls_today" in data + assert "denies_today" in data + + +# ============================================================================ +# Test: Disabled MCP (Opt-Out) +# ============================================================================ + + +@pytest.mark.asyncio +async def test_mcp_disabled_no_routes(monkeypatch, tmp_path): + """Test that MCP routes are not mounted when disabled.""" + attach_dir = tmp_path / "attach_disabled" + attach_dir.mkdir() + + import attach.mcp.config + + monkeypatch.setattr(attach.mcp.config, "_attach_dir_path", lambda: attach_dir) + monkeypatch.delenv("ATTACH_ENABLE_MCP", raising=False) + + # No mcp.json, no env var = MCP disabled + app = create_app() + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.get( + "/mcp", headers={"Authorization": f"Bearer {USER_ALICE_TOKEN}"} + ) + assert response.status_code == 404 + + response = await client.get("/console") + assert response.status_code == 404 + + +# ============================================================================ +# Test: Server Not Found / Disabled +# ============================================================================ + + +@pytest.mark.asyncio +async def test_mcp_server_not_found( + temp_attach_dir, setup_mcp_servers, init_audit_db +): + """Test requesting a non-existent MCP server.""" + app = create_app() + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/mcp/nonexistent", + json={"jsonrpc": "2.0", "method": "test", "id": 1}, + headers={"Authorization": f"Bearer {USER_ALICE_TOKEN}"}, + ) + assert response.status_code == 404 + data = response.json() + assert 
"error" in data + + +@pytest.mark.asyncio +async def test_mcp_disabled_server( + temp_attach_dir, setup_mcp_servers, init_audit_db +): + """Test requesting a disabled MCP server.""" + app = create_app() + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/mcp/disabled-server", + json={"jsonrpc": "2.0", "method": "test", "id": 1}, + headers={"Authorization": f"Bearer {USER_ALICE_TOKEN}"}, + ) + assert response.status_code == 404 + data = response.json() + assert "error" in data From 7b0e42f74fad5f9b9812756235c7196c71dc2ad2 Mon Sep 17 00:00:00 2001 From: Hammad Tariq Date: Sun, 18 Jan 2026 01:45:28 +0500 Subject: [PATCH 5/5] chore: remove IMPLEMENTATION_SUMMARY.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implementation details are documented in the PR description. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- IMPLEMENTATION_SUMMARY.md | 344 -------------------------------------- 1 file changed, 344 deletions(-) delete mode 100644 IMPLEMENTATION_SUMMARY.md diff --git a/IMPLEMENTATION_SUMMARY.md b/IMPLEMENTATION_SUMMARY.md deleted file mode 100644 index dfc89d9..0000000 --- a/IMPLEMENTATION_SUMMARY.md +++ /dev/null @@ -1,344 +0,0 @@ -# Plan 1 MCP Gateway Implementation Summary - -## Overview - -Successfully implemented Plan 1 - an OSS-friendly, local-first MCP Gateway layer inside Attach Gateway that is OPT-IN and backwards compatible with existing OIDC/JWT sidecar functionality. - -## Implementation Status: ✅ COMPLETE - -All 17 planned deliverables have been implemented and tested. - ---- - -## Key Features Delivered - -### 1. 
MCP Configuration & Lifecycle -- **Module**: `mcp/config.py` -- **Config file**: `~/.attach/mcp.json` -- **CLI commands**: - - `attach-gateway mcp list` - - `attach-gateway mcp add [--header ...]` - - `attach-gateway mcp enable/disable ` - - `attach-gateway mcp remove ` -- **Features**: - - HTTP upstream support (stdio spawning not in MVP) - - Header resolution with `env:VARNAME` syntax - - Never logs resolved secrets - -### 2. MCP Reverse Proxy Endpoints -- **Module**: `mcp/router.py`, `mcp/proxy.py` -- **Routes**: - - `POST /mcp/{server}` - Forwards JSON-RPC to configured upstream - - `GET /mcp` - Returns list of servers + enabled state -- **Security**: All `/mcp/*` endpoints require Bearer JWT authentication -- **Error handling**: JSON-RPC error responses for quota/timeout/upstream failures - -### 3. Audit Logging (Local SQLite) -- **Module**: `audit/sqlite.py` -- **Database**: `~/.attach/attach.db` -- **Tables**: - - `mcp_events` - Stores metadata: ts, user, server, method, tool, allowed, latency_ms, error - - `mcp_counters` - Quota usage tracking by date_utc/user/tool -- **Privacy**: No request/response bodies stored, only metadata - -### 4. Quota Enforcement -- **Module**: `mcp/quota.py` -- **Policy file**: `~/.attach/mcp_policy.json` -- **Features**: - - Per-user daily call limits - - Glob pattern matching for tool names (fnmatch) - - Only enforces on `tools/call` JSON-RPC method - - Denied calls return JSON-RPC error (HTTP 200) and are logged - -### 5. 
Console UI -- **Module**: `console/router.py`, `console/static/` -- **Routes**: - - `GET /console` - Static HTML (unauthenticated, no sensitive data) - - `GET /console/static/*` - Static assets (unauthenticated) - - `GET /console/api/overview` - Statistics (JWT protected) - - `GET /console/api/events` - Event log (JWT protected) - - `GET /console/api/servers` - Server list (JWT protected) -- **Pages**: - - Overview: calls today, denies today, top tools, top users - - Events table: timestamp, user, server, method, tool, allowed, latency, error - - Servers list: enabled servers with URLs -- **Auth model**: Landing page public, API endpoints require JWT - -### 6. Claude Code Installer Helper -- **Module**: `attach/cli_claude.py` -- **Command**: `attach-gateway claude install [--project .] [--bearer ] [--write-file]` -- **Modes**: - - Default: Prints `claude mcp add` commands (recommended, no schema brittleness) - - `--write-file`: Writes `.mcp.json` directly (experimental, schema may change) - - `--bearer`: Includes Authorization header if Claude supports it -- **Safety**: Avoids schema brittleness by preferring command generation - -### 7. 
OpenMeter Non-Fatal Fix -- **Module**: `usage/factory.py` -- **Change**: `USAGE_METERING=openmeter` without `OPENMETER_API_KEY` now logs warning and falls back to `NullUsageBackend` instead of crashing -- **Impact**: Gateway always starts successfully; missing optional config is non-fatal - ---- - -## Modules Added - -### New Top-Level Packages -- `mcp/` - MCP gateway core functionality - - `__init__.py` - - `config.py` - Configuration management - - `proxy.py` - HTTP forwarding logic - - `router.py` - FastAPI routes - - `quota.py` - Quota enforcement - -- `audit/` - Audit logging - - `__init__.py` - - `sqlite.py` - SQLite-based event logging - -- `console/` - Web console UI - - `__init__.py` - - `router.py` - FastAPI routes - - `static/` - - `index.html` - Single-page console app - - `app.js` - Client-side logic with localStorage JWT - - `style.css` - Styling - -### Modified Core Modules -- `attach/gateway.py` - Conditional MCP/console router mounting (opt-in) -- `attach/__main__.py` - Added subcommand support (backward compatible) -- `attach/cli_mcp.py` - MCP server management commands -- `attach/cli_claude.py` - Claude Code integration helper -- `middleware/auth.py` - Added `/console` and `/console/static/*` exclusions -- `middleware/session.py` - Added `/console` and `/console/static/*` exclusions -- `usage/factory.py` - OpenMeter non-fatal fallback -- `pyproject.toml` - Added `mcp`, `audit`, `console` to packages list - ---- - -## Tests Added - -All tests follow existing patterns from `tests/test_jwt_middleware.py`: - -1. **`tests/test_mcp_optin.py`** - - Tests MCP routes are 404 without opt-in - - Tests MCP routes available with `ATTACH_ENABLE_MCP=true` - - Tests MCP routes available with `~/.attach/mcp.json` present - -2. **`tests/test_mcp_proxy_quota.py`** - - Tests MCP proxy forwards requests to upstream - - Tests quota enforcement denies after limit exceeded - - Tests unknown server returns 404 with JSON-RPC error - -3. 
**`tests/test_console_auth.py`** - - Tests `/console` accessible without auth - - Tests `/console/static/*` accessible without auth - - Tests `/console/api/*` requires JWT (401 without, 200 with valid token) - -4. **`tests/test_openmeter_fallback.py`** - - Tests `USAGE_METERING=openmeter` without key doesn't crash - - Tests falls back to `NullUsageBackend` - - Tests `USAGE_METERING=openmeter` with key uses `OpenMeterBackend` - ---- - -## Opt-In Mechanism - -MCP Gateway is enabled if either: -1. Environment variable `ATTACH_ENABLE_MCP=true` is set, OR -2. Config file `~/.attach/mcp.json` exists - -When disabled: -- `/mcp` and `/console` routes are NOT mounted (404) -- No MCP-related imports or initialization -- Zero performance impact on core OIDC sidecar functionality - ---- - -## Security Model - -### Authentication Requirements -| Path | Auth Required | Notes | -|------|---------------|-------| -| `/mcp/*` | ✅ Yes | Bearer JWT required | -| `/console` | ❌ No | Static HTML, no sensitive data | -| `/console/static/*` | ❌ No | CSS/JS/images only | -| `/console/api/*` | ✅ Yes | Bearer JWT required | - -### Privacy & Security Features -- JWT validation uses existing OIDC/DID infrastructure -- Client `Authorization` header NOT forwarded to upstream MCP servers -- Upstream headers configured separately with `env:` support -- Audit logs contain metadata only (no bodies) -- All data local by default (no phone-home) -- User sub logged as first 8 chars only in some places - ---- - -## Backward Compatibility - -### ✅ Guaranteed Safe -- Default behavior unchanged: `attach-gateway --port 8080` works exactly as before -- Existing routes unaffected: `/api/chat`, `/a2a/*`, `/mem/*` unchanged -- Auth middleware preserves exact behavior for existing paths -- No new required dependencies -- No breaking changes to environment variables - -### CLI Changes (Backward Compatible) -- `attach-gateway --port 8080` still runs server (default command) -- New subcommands: `attach-gateway mcp 
...`, `attach-gateway claude ...` -- Uses `click.Group(invoke_without_command=True)` pattern - ---- - -## README Updates - -Added comprehensive section: "Claude Code + MCP Gateway (Local-First, 2-Minute Setup)" - -Includes: -- Feature overview and benefits -- Quick setup instructions -- CLI command examples -- Claude Code integration guide -- Console UI usage -- How it works diagram -- Security notes - -Location: Between main quickstart and "Use in your project" sections - ---- - -## Definition of Done - Verification - -✅ **Default behavior unchanged** - MCP disabled by default -✅ **MCP explicitly enabled** - Via env var or config file -✅ **Routes mounted conditionally** - `/mcp` and `/console` only when enabled -✅ **JWT protection maintained** - No auth weakening for `/mcp` -✅ **Console auth model secure** - Public landing page, protected API -✅ **Audit logs working** - SQLite metadata storage -✅ **Quota enforcement working** - Glob patterns, daily limits, JSON-RPC errors -✅ **Claude installer helper** - Prints valid commands -✅ **OpenMeter never crashes** - Graceful fallback to NullUsageBackend -✅ **Tests written** - 4 test files covering all features -✅ **README updated** - Stars-magnet quickstart section added -✅ **No new deps** - Uses stdlib + existing FastAPI/httpx/click -✅ **Syntax validated** - All modules compile successfully - ---- - -## Known Limitations (As Designed for MVP) - -1. **HTTP transport only** - No stdio process spawning yet -2. **Single gateway instance** - Quota counters are local (not distributed) -3. **UTC day boundary** - Quota resets at midnight UTC regardless of timezone -4. **No response body caching** - Audit logs store metadata only -5. **Basic glob patterns** - Uses fnmatch, not full regex -6. **No rate limiting** - Only daily quotas, no per-minute throttling - ---- - -## Next Steps (Future Work, Not in Plan 1) - -1. **Stdio transport** - Spawn MCP servers as child processes -2. 
**Distributed quota counters** - Redis/Postgres backend for multi-instance -3. **Response caching** - Optional LLM response memoization -4. **Advanced rate limiting** - Per-minute/hour sliding windows -5. **Webhook notifications** - Alert on quota exceeded -6. **RBAC policies** - Role-based tool access control -7. **Prometheus metrics** - MCP-specific Prometheus exports - ---- - -## File Manifest - -### New Files -``` -mcp/__init__.py -mcp/config.py -mcp/proxy.py -mcp/quota.py -mcp/router.py -audit/__init__.py -audit/sqlite.py -console/__init__.py -console/router.py -console/static/index.html -console/static/app.js -console/static/style.css -attach/cli_mcp.py -attach/cli_claude.py -tests/test_mcp_optin.py -tests/test_mcp_proxy_quota.py -tests/test_console_auth.py -tests/test_openmeter_fallback.py -IMPLEMENTATION_SUMMARY.md -``` - -### Modified Files -``` -attach/gateway.py -attach/__main__.py -middleware/auth.py -middleware/session.py -usage/factory.py -pyproject.toml -README.md -``` - ---- - -## Commands Reference - -### Start Gateway with MCP -```bash -export OIDC_ISSUER=https://your-domain.auth0.com/ -export OIDC_AUD=your-api-identifier -export ATTACH_ENABLE_MCP=true - -attach-gateway --port 8080 -``` - -### Configure MCP Server -```bash -attach-gateway mcp add notion http://localhost:7001/mcp \ - --header "Authorization: env:NOTION_TOKEN" - -attach-gateway mcp enable notion -attach-gateway mcp list -``` - -### Setup Quota Policy -```bash -cat > ~/.attach/mcp_policy.json <