From 90e872076825dfc7f3931e61a4e68d90e94a6933 Mon Sep 17 00:00:00 2001 From: Jerome Swannack Date: Mon, 19 Jan 2026 16:03:15 +0000 Subject: [PATCH 1/3] feat: add message middleware support for ClientSession and ServerSession Add a middleware pattern that allows transforming JSON-RPC messages before sending and after receiving. This provides a clean way to extend protocol messages (e.g., adding custom capabilities to initialize requests) without needing to subclass or override session methods. Middleware functions receive a JSONRPCMessage and return a (possibly transformed) JSONRPCMessage. Both sync and async middleware are supported. Usage example: def add_capabilities(message: JSONRPCMessage) -> JSONRPCMessage: if isinstance(message.root, JSONRPCRequest): # Transform the message... pass return message session = ClientSession( read_stream, write_stream, send_middleware=[add_capabilities], ) Changes: - Add MessageMiddleware type alias in mcp.shared.session - Add send_middleware and receive_middleware parameters to BaseSession - Apply middleware in send_request, send_notification, _send_response - Apply middleware in _receive_loop after receiving messages - Export MessageMiddleware, JSONRPCMessage, JSONRPCNotification from mcp - Add tests for sync and async middleware --- src/mcp/__init__.py | 6 ++ src/mcp/client/session.py | 6 +- src/mcp/server/session.py | 13 +++- src/mcp/shared/session.py | 50 +++++++++++-- tests/client/test_session.py | 141 +++++++++++++++++++++++++++++++++++ 5 files changed, 208 insertions(+), 8 deletions(-) diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index fbec40d0a9..cb0213fe28 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -4,6 +4,7 @@ from .server.session import ServerSession from .server.stdio import stdio_server from .shared.exceptions import McpError, UrlElicitationRequiredError +from .shared.session import MessageMiddleware from .types import ( CallToolRequest, ClientCapabilities, @@ -23,6 +24,8 @@ InitializeRequest, InitializeResult, JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, ListPromptsRequest, @@ -87,8 +90,11 @@ "InitializeResult", "InitializedNotification", "JSONRPCError", + "JSONRPCMessage", + "JSONRPCNotification", "JSONRPCRequest", "JSONRPCResponse", + "MessageMiddleware", "ListPromptsRequest", "ListPromptsResult", "ListResourcesRequest", diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 8519f15cec..b2fa87cd87 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -12,7 +12,7 @@ from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder +from mcp.shared.session import BaseSession, MessageMiddleware, ProgressFnT, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") @@ -123,6 +123,8 @@ def __init__( *, sampling_capabilities: types.SamplingCapability | None = None, experimental_task_handlers: ExperimentalTaskHandlers | None = None, + send_middleware: list["MessageMiddleware"] | None = None, + receive_middleware: list["MessageMiddleware"] | None = None, ) -> None: super().__init__( read_stream, @@ -130,6 +132,8 @@ def __init__( types.ServerRequest, types.ServerNotification, read_timeout_seconds=read_timeout_seconds, + send_middleware=send_middleware, + receive_middleware=receive_middleware, ) self._client_info = client_info or DEFAULT_CLIENT_INFO self._sampling_callback = sampling_callback or _default_sampling_callback diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 8f0baa3e9c..d8e3243ad6 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -54,6 +54,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, + MessageMiddleware, RequestResponder, ) from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -91,8 +92,18 @@ def __init__( write_stream: MemoryObjectSendStream[SessionMessage], init_options: InitializationOptions, stateless: bool = False, + *, + send_middleware: list["MessageMiddleware"] | None = None, + receive_middleware: list["MessageMiddleware"] | None = None, ) -> None: - super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification) + super().__init__( + read_stream, + write_stream, + types.ClientRequest, + types.ClientNotification, + send_middleware=send_middleware, + receive_middleware=receive_middleware, + ) self._initialization_state = ( InitializationState.Initialized if stateless else InitializationState.NotInitialized ) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 3033acd0eb..169b1fde7a 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,5 +1,5 @@ import logging -from collections.abc import Callable +from collections.abc import Awaitable, Callable from contextlib import AsyncExitStack from datetime import timedelta from types import TracebackType @@ -43,6 +43,10 @@ RequestId = str | int +# Middleware type for transforming messages before sending or after receiving. +# Can be sync (returns JSONRPCMessage) or async (returns Awaitable[JSONRPCMessage]). +MessageMiddleware = Callable[[JSONRPCMessage], JSONRPCMessage | Awaitable[JSONRPCMessage]] + class ProgressFnT(Protocol): """Protocol for progress notification callbacks.""" @@ -190,6 +194,9 @@ def __init__( receive_notification_type: type[ReceiveNotificationT], # If none, reading will never time out read_timeout_seconds: timedelta | None = None, + *, + send_middleware: list[MessageMiddleware] | None = None, + receive_middleware: list[MessageMiddleware] | None = None, ) -> None: self._read_stream = read_stream self._write_stream = write_stream @@ -202,6 +209,22 @@ def __init__( self._progress_callbacks = {} self._response_routers = [] self._exit_stack = AsyncExitStack() + self._send_middleware = send_middleware or [] + self._receive_middleware = receive_middleware or [] + + async def _apply_middleware( + self, message: JSONRPCMessage, middleware_list: list[MessageMiddleware] + ) -> JSONRPCMessage: + """Apply a list of middleware functions to a message.""" + import inspect + + for middleware in middleware_list: + result = middleware(message) + if inspect.isawaitable(result): + message = await result + else: + message = result # type: ignore[assignment] + return message def add_response_router(self, router: ResponseRouter) -> None: """ @@ -278,7 +301,9 @@ async def send_request( **request_data, ) - await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata)) + message = JSONRPCMessage(jsonrpc_request) + message = await self._apply_middleware(message, self._send_middleware) + await self._write_stream.send(SessionMessage(message=message, metadata=metadata)) # request read timeout takes precedence over session read timeout timeout = None @@ -328,8 +353,10 @@ async def send_notification( jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json", exclude_none=True), ) + message = JSONRPCMessage(jsonrpc_notification) + message = await self._apply_middleware(message, self._send_middleware) session_message = SessionMessage( # pragma: no cover - message=JSONRPCMessage(jsonrpc_notification), + message=message, metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None, ) await self._write_stream.send(session_message) @@ -337,7 +364,9 @@ async def send_notification( async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: if isinstance(response, ErrorData): jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) - session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) + message = JSONRPCMessage(jsonrpc_error) + message = await self._apply_middleware(message, self._send_middleware) + session_message = SessionMessage(message=message) await self._write_stream.send(session_message) else: jsonrpc_response = JSONRPCResponse( @@ -345,7 +374,9 @@ async def _send_response(self, request_id: RequestId, response: SendResultT | Er id=request_id, result=response.model_dump(by_alias=True, mode="json", exclude_none=True), ) - session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) + message = JSONRPCMessage(jsonrpc_response) + message = await self._apply_middleware(message, self._send_middleware) + session_message = SessionMessage(message=message) await self._write_stream.send(session_message) async def _receive_loop(self) -> None: @@ -357,7 +388,14 @@ async def _receive_loop(self) -> None: async for message in self._read_stream: if isinstance(message, Exception): # pragma: no cover await self._handle_incoming(message) - elif isinstance(message.message.root, JSONRPCRequest): + continue + + # Apply receive middleware to transform the message + if self._receive_middleware: + transformed_msg = await self._apply_middleware(message.message, self._receive_middleware) + message = SessionMessage(message=transformed_msg, metadata=message.metadata) # noqa: PLW2901 + + if isinstance(message.message.root, JSONRPCRequest): try: validated_request = self._receive_request_type.model_validate( message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) diff --git a/tests/client/test_session.py b/tests/client/test_session.py index eb2683fbdb..8c1c774d2a 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -768,3 +768,144 @@ async def mock_server(): await session.initialize() await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta) + + +@pytest.mark.anyio +async def test_client_session_send_middleware(): + """Test that send middleware can transform outgoing messages.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + received_request = None + middleware_called = False + + def add_custom_field(message: JSONRPCMessage) -> JSONRPCMessage: + """Middleware that adds a custom field to initialize request params.""" + nonlocal middleware_called + middleware_called = True + + if isinstance(message.root, JSONRPCRequest): + # Add custom extension to the capabilities + data = message.root.model_dump(by_alias=True, mode="json", exclude_none=True) + if data.get("method") == "initialize" and "params" in data: + if "capabilities" not in data["params"]: + data["params"]["capabilities"] = {} + # Add a custom extension field + data["params"]["capabilities"]["customExtension"] = {"enabled": True} + return JSONRPCMessage(JSONRPCRequest(**data)) + return message + + async def mock_server(): + nonlocal received_request + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + received_request = jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True) + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + # Receive initialized notification + await client_to_server_receive.receive() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + send_middleware=[add_custom_field], + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Verify middleware was called and transformed the request + assert middleware_called + assert received_request is not None + assert "params" in received_request + assert "capabilities" in received_request["params"] + assert "customExtension" in received_request["params"]["capabilities"] + assert received_request["params"]["capabilities"]["customExtension"] == {"enabled": True} + + +@pytest.mark.anyio +async def test_client_session_async_middleware(): + """Test that async middleware works correctly.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + middleware_called = False + + async def async_middleware(message: JSONRPCMessage) -> JSONRPCMessage: + """Async middleware that just passes through.""" + nonlocal middleware_called + middleware_called = True + # Simulate some async work + await anyio.sleep(0) + return message + + async def mock_server(): + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + await client_to_server_receive.receive() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + send_middleware=[async_middleware], + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + assert middleware_called From ac971959ceb3799a36dcf6bdefe62b18223c1111 Mon Sep 17 00:00:00 2001 From: Jerome Swannack Date: Mon, 19 Jan 2026 16:10:57 +0000 Subject: [PATCH 2/3] fix: address review feedback and add receive middleware test - Move 'import inspect' to top of file - Pre-compute whether middleware is async using inspect.iscoroutinefunction() instead of checking on every message - Add test for receive_middleware to fix coverage --- src/mcp/shared/session.py | 25 +++++++------- tests/client/test_session.py | 64 ++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 11 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 169b1fde7a..e75b1c42b5 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,3 +1,4 @@ +import inspect import logging from collections.abc import Awaitable, Callable from contextlib import AsyncExitStack @@ -209,22 +210,24 @@ def __init__( self._progress_callbacks = {} self._response_routers = [] self._exit_stack = AsyncExitStack() - self._send_middleware = send_middleware or [] - self._receive_middleware = receive_middleware or [] + # Pre-compute whether each middleware is async to avoid checking on every message + self._send_middleware: list[tuple[MessageMiddleware, bool]] = [ + (m, inspect.iscoroutinefunction(m)) for m in (send_middleware or []) + ] + self._receive_middleware: list[tuple[MessageMiddleware, bool]] = [ + (m, inspect.iscoroutinefunction(m)) for m in (receive_middleware or []) + ] async def _apply_middleware( - self, message: JSONRPCMessage, middleware_list: list[MessageMiddleware] + self, message: JSONRPCMessage, middleware_list: list[tuple[MessageMiddleware, bool]] ) -> JSONRPCMessage: """Apply a list of middleware functions to a message.""" - import inspect - - for middleware in middleware_list: + for middleware, is_async in middleware_list: result = middleware(message) - if inspect.isawaitable(result): - message = await result - else: - message = result # type: ignore[assignment] - return message + if is_async: + result = await result # type: ignore[misc] + message = result # type: ignore[assignment] + return message # type: ignore[return-value] def add_response_router(self, router: ResponseRouter) -> None: """ diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 8c1c774d2a..b9af6490dc 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -909,3 +909,67 @@ async def mock_server(): await session.initialize() assert middleware_called + + +@pytest.mark.anyio +async def test_client_session_receive_middleware(): + """Test that receive middleware can transform incoming messages.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + middleware_called = False + received_response = None + + def receive_transform(message: JSONRPCMessage) -> JSONRPCMessage: + """Middleware that observes incoming messages.""" + nonlocal middleware_called, received_response + middleware_called = True + if isinstance(message.root, JSONRPCResponse): + received_response = message.root + return message + + async def mock_server(): + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + await client_to_server_receive.receive() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + receive_middleware=[receive_transform], + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Verify receive middleware was called and saw the response + assert middleware_called + assert received_response is not None From 0f3edd4c5cc3fc58ee7bb106359a1619c46a622d Mon Sep 17 00:00:00 2001 From: Jerome Swannack Date: Mon, 19 Jan 2026 16:22:57 +0000 Subject: [PATCH 3/3] fix: add pragma comments for test branch coverage --- tests/client/test_session.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/client/test_session.py b/tests/client/test_session.py index b9af6490dc..dbc59f7a11 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -784,16 +784,16 @@ def add_custom_field(message: JSONRPCMessage) -> JSONRPCMessage: nonlocal middleware_called middleware_called = True - if isinstance(message.root, JSONRPCRequest): + if isinstance(message.root, JSONRPCRequest): # pragma: no branch # Add custom extension to the capabilities data = message.root.model_dump(by_alias=True, mode="json", exclude_none=True) - if data.get("method") == "initialize" and "params" in data: - if "capabilities" not in data["params"]: + if data.get("method") == "initialize" and "params" in data: # pragma: no branch + if "capabilities" not in data["params"]: # pragma: no cover data["params"]["capabilities"] = {} # Add a custom extension field data["params"]["capabilities"]["customExtension"] = {"enabled": True} return JSONRPCMessage(JSONRPCRequest(**data)) - return message + return message # pragma: no cover async def mock_server(): nonlocal received_request @@ -924,7 +924,7 @@ def receive_transform(message: JSONRPCMessage) -> JSONRPCMessage: """Middleware that observes incoming messages.""" nonlocal middleware_called, received_response middleware_called = True - if isinstance(message.root, JSONRPCResponse): + if isinstance(message.root, JSONRPCResponse): # pragma: no branch received_response = message.root return message