Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 81 additions & 29 deletions src/openagents/utils/mcp_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import logging
import os
from contextlib import AsyncExitStack
from typing import Any, Dict, List

from mcp import ClientSession, StdioServerParameters
Expand All @@ -28,6 +29,7 @@ def __init__(self):
self._mcp_clients: Dict[str, Any] = {}
self._mcp_tools: List[AgentTool] = []
self._mcp_sessions: Dict[str, ClientSession] = {}
self._exit_stacks: Dict[str, AsyncExitStack] = {}

async def setup_mcp_clients(self, mcp_configs: List[MCPServerConfig]) -> List[AgentTool]:
"""Setup MCP clients based on configuration.
Expand Down Expand Up @@ -91,25 +93,31 @@ async def _setup_stdio_mcp_client(self, mcp_config: MCPServerConfig):
env=env
)

# Use the stdio client from the MCP library
transport = await stdio_client(server_params).__aenter__()
read_stream, write_stream = transport

# Create a session over those streams
session = ClientSession(read_stream, write_stream)
await session.__aenter__()

# Use AsyncExitStack for proper context manager lifecycle
exit_stack = AsyncExitStack()
await exit_stack.__aenter__()

read_stream, write_stream = await exit_stack.enter_async_context(
stdio_client(server_params)
)

session = await exit_stack.enter_async_context(
ClientSession(read_stream, write_stream)
)

# Initialize the session
await session.initialize()

logger.info(f"Connected to stdio MCP server '{mcp_config.name}'")

# Store the exit stack for cleanup
self._exit_stacks[mcp_config.name] = exit_stack

# Store the session and transport info
mcp_client = {
"name": mcp_config.name,
"type": "stdio",
"session": session,
"transport": transport,
"config": mcp_config
}

Expand All @@ -120,6 +128,11 @@ async def _setup_stdio_mcp_client(self, mcp_config: MCPServerConfig):
await self._add_mcp_tools(mcp_config.name, session)

except Exception as e:
if mcp_config.name in self._exit_stacks:
try:
await self._exit_stacks.pop(mcp_config.name).__aexit__(None, None, None)
except Exception:
pass
logger.error(f"Failed to start stdio MCP server '{mcp_config.name}': {e}")
raise

Expand All @@ -129,26 +142,32 @@ async def _setup_sse_mcp_client(self, mcp_config: MCPServerConfig):
raise ValueError(f"URL is required for sse MCP server '{mcp_config.name}'")

try:
# Use the SSE client from the MCP library
transport = await sse_client(mcp_config.url).__aenter__()
read_stream, write_stream = transport

# Create a session over those streams
session = ClientSession(read_stream, write_stream)
await session.__aenter__()

# Use AsyncExitStack for proper context manager lifecycle
exit_stack = AsyncExitStack()
await exit_stack.__aenter__()

read_stream, write_stream = await exit_stack.enter_async_context(
sse_client(mcp_config.url)
)

session = await exit_stack.enter_async_context(
ClientSession(read_stream, write_stream)
)

# Initialize the session
await session.initialize()

logger.info(f"Connected to SSE MCP server '{mcp_config.name}'")

# Store the exit stack for cleanup
self._exit_stacks[mcp_config.name] = exit_stack

# Store the session and transport info
mcp_client = {
"name": mcp_config.name,
"type": "sse",
"url": mcp_config.url,
"session": session,
"transport": transport,
"config": mcp_config
}

Expand All @@ -159,6 +178,11 @@ async def _setup_sse_mcp_client(self, mcp_config: MCPServerConfig):
await self._add_mcp_tools(mcp_config.name, session)

except Exception as e:
if mcp_config.name in self._exit_stacks:
try:
await self._exit_stacks.pop(mcp_config.name).__aexit__(None, None, None)
except Exception:
pass
logger.error(f"Failed to setup SSE MCP server '{mcp_config.name}': {e}")
raise

Expand All @@ -168,27 +192,36 @@ async def _setup_streamable_http_mcp_client(self, mcp_config: MCPServerConfig):
raise ValueError(f"URL is required for streamable_http MCP server '{mcp_config.name}'")

try:
# Use the streamable HTTP client from the MCP library
transport = await streamablehttp_client(mcp_config.url).__aenter__()
read_stream, write_stream, get_session_id = transport

# Create a session over those streams
session = ClientSession(read_stream, write_stream)
await session.__aenter__()

# Use AsyncExitStack to properly manage the nested async context managers.
# streamablehttp_client uses an internal anyio TaskGroup that MUST be
# exited in the same task it was entered — manual __aenter__/__aexit__
# breaks this invariant and causes CancelledError.
exit_stack = AsyncExitStack()
await exit_stack.__aenter__()

read_stream, write_stream, get_session_id = await exit_stack.enter_async_context(
streamablehttp_client(mcp_config.url)
)

session = await exit_stack.enter_async_context(
ClientSession(read_stream, write_stream)
)

# Initialize the session
await session.initialize()

session_id = get_session_id()
logger.info(f"Connected to streamable HTTP MCP server '{mcp_config.name}' with session ID: {session_id}")

# Store the exit stack so cleanup can close everything properly
self._exit_stacks[mcp_config.name] = exit_stack

# Store the session and transport info
mcp_client = {
"name": mcp_config.name,
"type": "streamable_http",
"url": mcp_config.url,
"session": session,
"transport": transport,
"session_id": session_id,
"config": mcp_config
}
Expand All @@ -200,6 +233,12 @@ async def _setup_streamable_http_mcp_client(self, mcp_config: MCPServerConfig):
await self._add_mcp_tools(mcp_config.name, session)

except Exception as e:
# Clean up the exit stack if setup failed partway through
if mcp_config.name in self._exit_stacks:
try:
await self._exit_stacks.pop(mcp_config.name).__aexit__(None, None, None)
except Exception:
pass
logger.error(f"Failed to setup streamable HTTP MCP server '{mcp_config.name}': {e}")
raise

Expand Down Expand Up @@ -259,16 +298,28 @@ async def cleanup_mcp_clients(self):
"""Cleanup all MCP clients and processes."""
logger.info("Cleaning up MCP clients")

# Cleanup MCP sessions
# Cleanup via exit stacks (handles streamable_http and any future stack-managed clients)
for name, exit_stack in self._exit_stacks.items():
try:
await exit_stack.__aexit__(None, None, None)
logger.debug(f"Closed MCP exit stack: {name}")
except Exception as e:
logger.error(f"Error closing MCP exit stack '{name}': {e}")

# Cleanup any remaining sessions not managed by exit stacks (e.g. stdio, sse)
for name, session in self._mcp_sessions.items():
if name in self._exit_stacks:
continue # Already cleaned up via exit stack
try:
await session.__aexit__(None, None, None)
logger.debug(f"Closed MCP session: {name}")
except Exception as e:
logger.error(f"Error closing MCP session '{name}': {e}")

# Cleanup MCP transport connections
# Cleanup any remaining transport connections not managed by exit stacks
for name, client in self._mcp_clients.items():
if name in self._exit_stacks:
continue # Already cleaned up via exit stack
if client.get("transport") and hasattr(client["transport"], "__aexit__"):
try:
await client["transport"].__aexit__(None, None, None)
Expand All @@ -280,6 +331,7 @@ async def cleanup_mcp_clients(self):
self._mcp_clients.clear()
self._mcp_tools.clear()
self._mcp_sessions.clear()
self._exit_stacks.clear()

def get_mcp_tools(self) -> List[AgentTool]:
"""Get all tools from connected MCP servers."""
Expand Down