-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmiddleware.py
More file actions
132 lines (102 loc) · 4.34 KB
/
middleware.py
File metadata and controls
132 lines (102 loc) · 4.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Custom middleware for MCP Search server."""
import hmac
import logging
import time
from typing import Optional

from fastmcp.server.dependencies import get_http_headers
from fastmcp.server.middleware import Middleware, MiddlewareContext
from mcp import McpError
from mcp.types import ErrorData
# Module-wide logger; the middleware classes in this file log through "mcp_search".
logger = logging.getLogger("mcp_search")
class BearerAuthMiddleware(Middleware):
    """
    Middleware that authenticates requests using Bearer token (API key).

    The API key should be passed in the Authorization header as:
        Authorization: Bearer <api_key>

    Every request except ``initialize`` must carry a matching token; otherwise
    an ``McpError`` with code -32001 is raised and the request is not forwarded.
    """

    # Error code attached to every authentication failure raised by this class.
    _AUTH_ERROR_CODE = -32001

    def __init__(self, api_key: str):
        """
        Initialize the authentication middleware.

        Args:
            api_key: The expected API key for authentication.

        Raises:
            ValueError: If ``api_key`` is empty.
        """
        if not api_key:
            raise ValueError("API key cannot be empty")
        self.api_key = api_key

    async def on_request(self, context: MiddlewareContext, call_next):
        """
        Authenticate all incoming requests.

        Args:
            context: The middleware context; only ``context.method`` is read here.
            call_next: Awaitable callable that continues the middleware chain.

        Returns:
            Whatever ``call_next(context)`` returns once authentication passes.

        Raises:
            McpError: If the Authorization header is missing, is not of the form
                ``Bearer <token>``, or carries a token that does not match the
                configured API key.
        """
        # Skip authentication for initialization (handshake)
        if context.method == "initialize":
            return await call_next(context)

        auth_header = self._get_auth_header()
        if not auth_header:
            raise self._reject(
                context.method,
                "Missing Authorization header",
                "Authentication required: Missing Authorization header",
            )

        # HTTP auth scheme names are case-insensitive, so "bearer x" and
        # "Bearer x" are treated alike. The single space separator is required.
        scheme, separator, token = auth_header.partition(" ")
        if not separator or scheme.lower() != "bearer":
            raise self._reject(
                context.method,
                "Invalid Authorization header format",
                "Authentication required: Invalid Authorization header format. Expected 'Bearer <token>'",
            )

        # Constant-time comparison so response timing does not reveal how much of
        # the key matched. Compare bytes: compare_digest rejects non-ASCII str.
        if not hmac.compare_digest(token.encode("utf-8"), self.api_key.encode("utf-8")):
            raise self._reject(
                context.method,
                "Invalid API key",
                "Authentication failed: Invalid API key",
            )

        logger.debug(f"✓ Authentication successful for {context.method}")
        return await call_next(context)

    def _reject(self, method, reason: str, message: str) -> McpError:
        """Log an authentication failure and build the McpError for the caller to raise."""
        logger.warning(f"✗ Authentication failed for {method}: {reason}")
        return McpError(ErrorData(code=self._AUTH_ERROR_CODE, message=message))

    def _get_auth_header(self) -> Optional[str]:
        """
        Extract Authorization header from HTTP request.

        Returns:
            The header value, or ``None`` when no headers are available or the
            header is absent. The value itself is never logged or printed, since
            it contains the caller's credential.
        """
        try:
            headers = get_http_headers(include_all=True)
        except Exception:
            # Best effort: not in HTTP context or headers not available.
            # Treated as "no header" so the caller rejects the request.
            logger.debug("Could not read HTTP headers for authentication", exc_info=True)
            return None

        if not headers:
            return None
        return headers.get("authorization") or headers.get("Authorization")
class RequestLoggingMiddleware(Middleware):
    """
    Middleware that records the outcome of every MCP message.

    Each log entry carries a label (the method name, or ``tool/<name>`` for tool
    calls), the elapsed time in milliseconds measured with ``time.perf_counter``,
    and a success/failure marker. Failures are logged and re-raised unchanged.
    """

    async def on_message(self, context: MiddlewareContext, call_next):
        """Time any MCP message and log it under its method name."""
        return await self._run_logged(context.method, context, call_next)

    async def on_call_tool(self, context: MiddlewareContext, call_next):
        """Time a tool call and log it as ``tool/<tool name>``."""
        return await self._run_logged(f"tool/{context.message.name}", context, call_next)

    async def _run_logged(self, label, context: MiddlewareContext, call_next):
        """Await the rest of the chain, logging success at INFO or failure at ERROR."""
        began = time.perf_counter()
        try:
            outcome = await call_next(context)
            elapsed_ms = 1000 * (time.perf_counter() - began)
            logger.info(f"✓ {label} [{elapsed_ms:.1f}ms]")
            return outcome
        except Exception as exc:
            elapsed_ms = 1000 * (time.perf_counter() - began)
            logger.error(f"✗ {label} [{elapsed_ms:.1f}ms] - {type(exc).__name__}: {exc}")
            raise