diff --git a/src/openagents/utils/mcp_connector.py b/src/openagents/utils/mcp_connector.py index 6f2b68de5..0e7169e59 100644 --- a/src/openagents/utils/mcp_connector.py +++ b/src/openagents/utils/mcp_connector.py @@ -7,6 +7,7 @@ import logging import os +from contextlib import AsyncExitStack from typing import Any, Dict, List from mcp import ClientSession, StdioServerParameters @@ -28,6 +29,7 @@ def __init__(self): self._mcp_clients: Dict[str, Any] = {} self._mcp_tools: List[AgentTool] = [] self._mcp_sessions: Dict[str, ClientSession] = {} + self._exit_stacks: Dict[str, AsyncExitStack] = {} async def setup_mcp_clients(self, mcp_configs: List[MCPServerConfig]) -> List[AgentTool]: """Setup MCP clients based on configuration. @@ -91,25 +93,31 @@ async def _setup_stdio_mcp_client(self, mcp_config: MCPServerConfig): env=env ) - # Use the stdio client from the MCP library - transport = await stdio_client(server_params).__aenter__() - read_stream, write_stream = transport - - # Create a session over those streams - session = ClientSession(read_stream, write_stream) - await session.__aenter__() - + # Use AsyncExitStack for proper context manager lifecycle + exit_stack = AsyncExitStack() + await exit_stack.__aenter__() + + read_stream, write_stream = await exit_stack.enter_async_context( + stdio_client(server_params) + ) + + session = await exit_stack.enter_async_context( + ClientSession(read_stream, write_stream) + ) + # Initialize the session await session.initialize() logger.info(f"Connected to stdio MCP server '{mcp_config.name}'") + # Store the exit stack for cleanup + self._exit_stacks[mcp_config.name] = exit_stack + # Store the session and transport info mcp_client = { "name": mcp_config.name, "type": "stdio", "session": session, - "transport": transport, "config": mcp_config } @@ -120,6 +128,11 @@ async def _setup_stdio_mcp_client(self, mcp_config: MCPServerConfig): await self._add_mcp_tools(mcp_config.name, session) except Exception as e: + if mcp_config.name in self._exit_stacks: + try: + await self._exit_stacks.pop(mcp_config.name).__aexit__(None, None, None) + except Exception: + pass logger.error(f"Failed to start stdio MCP server '{mcp_config.name}': {e}") raise @@ -129,26 +142,32 @@ async def _setup_sse_mcp_client(self, mcp_config: MCPServerConfig): raise ValueError(f"URL is required for sse MCP server '{mcp_config.name}'") try: - # Use the SSE client from the MCP library - transport = await sse_client(mcp_config.url).__aenter__() - read_stream, write_stream = transport - - # Create a session over those streams - session = ClientSession(read_stream, write_stream) - await session.__aenter__() - + # Use AsyncExitStack for proper context manager lifecycle + exit_stack = AsyncExitStack() + await exit_stack.__aenter__() + + read_stream, write_stream = await exit_stack.enter_async_context( + sse_client(mcp_config.url) + ) + + session = await exit_stack.enter_async_context( + ClientSession(read_stream, write_stream) + ) + # Initialize the session await session.initialize() logger.info(f"Connected to SSE MCP server '{mcp_config.name}'") + # Store the exit stack for cleanup + self._exit_stacks[mcp_config.name] = exit_stack + # Store the session and transport info mcp_client = { "name": mcp_config.name, "type": "sse", "url": mcp_config.url, "session": session, - "transport": transport, "config": mcp_config } @@ -159,6 +178,11 @@ async def _setup_sse_mcp_client(self, mcp_config: MCPServerConfig): await self._add_mcp_tools(mcp_config.name, session) except Exception as e: + if mcp_config.name in self._exit_stacks: + try: + await self._exit_stacks.pop(mcp_config.name).__aexit__(None, None, None) + except Exception: + pass logger.error(f"Failed to setup SSE MCP server '{mcp_config.name}': {e}") raise @@ -168,27 +192,36 @@ async def _setup_streamable_http_mcp_client(self, mcp_config: MCPServerConfig): raise ValueError(f"URL is required for streamable_http MCP server '{mcp_config.name}'") try: - # Use the streamable HTTP client from the MCP library - transport = await streamablehttp_client(mcp_config.url).__aenter__() - read_stream, write_stream, get_session_id = transport - - # Create a session over those streams - session = ClientSession(read_stream, write_stream) - await session.__aenter__() - + # Use AsyncExitStack to properly manage the nested async context managers. + # streamablehttp_client uses an internal anyio TaskGroup that MUST be + # exited in the same task it was entered — manual __aenter__/__aexit__ + # breaks this invariant and causes CancelledError. + exit_stack = AsyncExitStack() + await exit_stack.__aenter__() + + read_stream, write_stream, get_session_id = await exit_stack.enter_async_context( + streamablehttp_client(mcp_config.url) + ) + + session = await exit_stack.enter_async_context( + ClientSession(read_stream, write_stream) + ) + # Initialize the session await session.initialize() session_id = get_session_id() logger.info(f"Connected to streamable HTTP MCP server '{mcp_config.name}' with session ID: {session_id}") + # Store the exit stack so cleanup can close everything properly + self._exit_stacks[mcp_config.name] = exit_stack + # Store the session and transport info mcp_client = { "name": mcp_config.name, "type": "streamable_http", "url": mcp_config.url, "session": session, - "transport": transport, "session_id": session_id, "config": mcp_config } @@ -200,6 +233,12 @@ async def _setup_streamable_http_mcp_client(self, mcp_config: MCPServerConfig): await self._add_mcp_tools(mcp_config.name, session) except Exception as e: + # Clean up the exit stack if setup failed partway through + if mcp_config.name in self._exit_stacks: + try: + await self._exit_stacks.pop(mcp_config.name).__aexit__(None, None, None) + except Exception: + pass logger.error(f"Failed to setup streamable HTTP MCP server '{mcp_config.name}': {e}") raise @@ -259,16 +298,28 @@ async def cleanup_mcp_clients(self): """Cleanup all MCP clients and processes.""" logger.info("Cleaning up MCP clients") - # Cleanup MCP sessions + # Cleanup via exit stacks (handles streamable_http and any future stack-managed clients) + for name, exit_stack in self._exit_stacks.items(): + try: + await exit_stack.__aexit__(None, None, None) + logger.debug(f"Closed MCP exit stack: {name}") + except Exception as e: + logger.error(f"Error closing MCP exit stack '{name}': {e}") + + # Cleanup any remaining sessions not managed by exit stacks (e.g. stdio, sse) for name, session in self._mcp_sessions.items(): + if name in self._exit_stacks: + continue # Already cleaned up via exit stack try: await session.__aexit__(None, None, None) logger.debug(f"Closed MCP session: {name}") except Exception as e: logger.error(f"Error closing MCP session '{name}': {e}") - # Cleanup MCP transport connections + # Cleanup any remaining transport connections not managed by exit stacks for name, client in self._mcp_clients.items(): + if name in self._exit_stacks: + continue # Already cleaned up via exit stack if client.get("transport") and hasattr(client["transport"], "__aexit__"): try: await client["transport"].__aexit__(None, None, None) @@ -280,6 +331,7 @@ async def cleanup_mcp_clients(self): self._mcp_clients.clear() self._mcp_tools.clear() self._mcp_sessions.clear() + self._exit_stacks.clear() def get_mcp_tools(self) -> List[AgentTool]: """Get all tools from connected MCP servers."""