diff --git a/src/mcpadapt/core.py b/src/mcpadapt/core.py index 7ae969a..5790f99 100644 --- a/src/mcpadapt/core.py +++ b/src/mcpadapt/core.py @@ -176,6 +176,7 @@ def __init__( | list[StdioServerParameters | dict[str, Any]], adapter: ToolAdapter, connect_timeout: int = 30, + client_session_timeout_seconds: float | timedelta | None = 5, ): """ Manage the MCP server / client lifecycle and expose tools adapted with the adapter. @@ -185,6 +186,7 @@ def __init__( MCP server parameters (stdio or sse). Can be a list if you want to connect multiple MCPs at once. adapter (ToolAdapter): Adapter to use to convert MCP tools call into agentic framework tools. connect_timeout (int): Connection timeout in seconds to the mcp server (default is 30s). + client_session_timeout_seconds: Timeout for MCP ClientSession calls Raises: TimeoutError: When the connection to the mcp server time out. @@ -209,6 +211,7 @@ def __init__( self.thread = threading.Thread(target=self._run_loop, daemon=True) self.connect_timeout = connect_timeout + self.client_session_timeout_seconds = client_session_timeout_seconds def _run_loop(self): """Runs the event loop in a separate thread (for synchronous usage).""" @@ -217,7 +220,9 @@ def _run_loop(self): async def setup(): async with AsyncExitStack() as stack: connections = [ - await stack.enter_async_context(mcptools(params)) + await stack.enter_async_context( + mcptools(params, self.client_session_timeout_seconds) + ) for params in self.serverparams ] self.sessions, self.mcp_tools = [list(c) for c in zip(*connections)] @@ -323,7 +328,9 @@ async def __aenter__(self) -> list[Any]: self._ctxmanager = AsyncExitStack() connections = [ - await self._ctxmanager.enter_async_context(mcptools(params)) + await self._ctxmanager.enter_async_context( + mcptools(params, self.client_session_timeout_seconds) + ) for params in self.serverparams ] diff --git a/tests/test_core.py b/tests/test_core.py index 98ba851..f66170e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -153,6 +153,28 @@ async def echo_streamable_http_server(echo_server_streamable_http_script): process.wait() +@pytest.fixture +def slow_start_server_script(): + return dedent( + ''' + import time + from mcp.server.fastmcp import FastMCP + + # Sleep for 2 seconds to simulate slow startup + time.sleep(2) + + mcp = FastMCP("Slow Server") + + @mcp.tool() + def echo_tool(text: str) -> str: + """Echo the input text""" + return f"Echo: {text}" + + mcp.run() + ''' + ) + + def test_basic_sync(echo_server_script): with MCPAdapt( StdioServerParameters( @@ -319,3 +341,63 @@ async def test_basic_async_streamable_http(echo_streamable_http_server): ) as tools: assert len(tools) == 1 assert (await tools[0]({"text": "hello"})).content[0].text == "Echo: hello" + + +def test_connect_timeout(slow_start_server_script): + """Test that connect_timeout raises TimeoutError when server starts slowly""" + with pytest.raises( + TimeoutError, match="Couldn't connect to the MCP server after 1 seconds" + ): + with MCPAdapt( + StdioServerParameters( + command="uv", args=["run", "python", "-c", slow_start_server_script] + ), + DummyAdapter(), + connect_timeout=1, # 1 second timeout, server takes 2 seconds to start + ): + pass + + +def test_client_session_timeout_parameter_propagation(echo_server_script): + """Test that client_session_timeout_seconds parameter is properly stored and accessible""" + from datetime import timedelta + + # Test with float value + adapter_float = MCPAdapt( + StdioServerParameters( + command="uv", args=["run", "python", "-c", echo_server_script] + ), + DummyAdapter(), + client_session_timeout_seconds=2.5, + ) + assert adapter_float.client_session_timeout_seconds == 2.5 + + # Test with timedelta value + timeout_td = timedelta(seconds=3.0) + adapter_td = MCPAdapt( + StdioServerParameters( + command="uv", args=["run", "python", "-c", echo_server_script] + ), + DummyAdapter(), + client_session_timeout_seconds=timeout_td, + ) + assert adapter_td.client_session_timeout_seconds == timeout_td + + # Test with None value + adapter_none = MCPAdapt( + StdioServerParameters( + command="uv", args=["run", "python", "-c", echo_server_script] + ), + DummyAdapter(), + client_session_timeout_seconds=None, + ) + assert adapter_none.client_session_timeout_seconds is None + + # Test default value + adapter_default = MCPAdapt( + StdioServerParameters( + command="uv", args=["run", "python", "-c", echo_server_script] + ), + DummyAdapter(), + ) + assert adapter_default.client_session_timeout_seconds == 5