diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7151d57cd..d6ce1d8c5 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,4 +1,5 @@ import logging +from types import TracebackType from typing import Any, Protocol import anyio.lowlevel @@ -107,6 +108,8 @@ class ClientSession( types.ServerNotification, ] ): + _entered: bool + def __init__( self, read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], @@ -133,10 +136,29 @@ def __init__( self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._server_capabilities: types.ServerCapabilities | None = None self._experimental_features: ExperimentalClientFeatures | None = None + self._entered = False # Experimental: Task handlers (use defaults if not provided) self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers() + async def __aenter__(self) -> "ClientSession": + self._entered = True + await super().__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self._entered = False + await super().__aexit__(exc_type, exc_value, traceback) + + def _check_is_active(self) -> None: + if not self._entered: + raise RuntimeError("ClientSession must be used within an 'async with' block.") + @property def _receive_request_adapter(self) -> TypeAdapter[types.ServerRequest]: return types.server_request_adapter @@ -146,6 +168,7 @@ def _receive_notification_adapter(self) -> TypeAdapter[types.ServerNotification] return types.server_notification_adapter async def initialize(self) -> types.InitializeResult: + self._check_is_active() sampling = ( (self._sampling_capabilities or types.SamplingCapability()) if self._sampling_callback is not _default_sampling_callback diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 5c1f55d23..51b38a270 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, cast import anyio import pytest @@ -705,3 +705,18 @@ async def mock_server(): await session.initialize() await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta) + + +@pytest.mark.anyio +async def test_initialize_without_context_manager_raises_error(): + """Test that calling initialize() without entering the context manager raises RuntimeError.""" + send_stream, receive_stream = anyio.create_memory_object_stream[Any](0) + + read_stream = cast(Any, receive_stream) + write_stream = cast(Any, send_stream) + + async with send_stream, receive_stream: + session = ClientSession(read_stream, write_stream) + + with pytest.raises(RuntimeError, match="must be used within"): + await session.initialize()