diff --git a/verifiers/mcp/mcp_server_connection.py b/verifiers/mcp/mcp_server_connection.py new file mode 100644 index 000000000..675947651 --- /dev/null +++ b/verifiers/mcp/mcp_server_connection.py @@ -0,0 +1,99 @@ +import asyncio +import logging +from typing import Dict, Optional + +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client +from mcp.types import TextContent, Tool + +from .models import MCPServerConfig + + +class MCPServerConnection: + def __init__(self, config: MCPServerConfig, logger: logging.Logger): + self.config = config + self.logger = logger + self.session: Optional[ClientSession] = None + self.tools: Dict[str, Tool] = {} + + self._connection_task: Optional[asyncio.Task] = None + self._ready = asyncio.Event() + self._error: Optional[Exception] = None + self.loop: Optional[asyncio.AbstractEventLoop] = None + + async def connect(self): + # Record the loop this connection is bound to + self.loop = asyncio.get_running_loop() + self._connection_task = asyncio.create_task(self._get_connection()) + + await self._ready.wait() + + if self._error: + raise self._error + + return self.tools + + async def _get_connection(self): + try: + server_params = StdioServerParameters( + command=self.config.command, + args=self.config.args or [], + env=self.config.env, + ) + + async with stdio_client(server_params) as (read, write): + async with ClientSession(read, write) as session: + self.session = session + + await session.initialize() + + tools_response = await session.list_tools() + + for tool in tools_response.tools: + self.tools[tool.name] = tool + + self._ready.set() + + while True: + await asyncio.sleep(1) + + except asyncio.CancelledError: + raise + except Exception as e: + self._error = e + self._ready.set() + finally: + self.session = None + self.tools = {} + + async def call_tool(self, tool_name: str, arguments: dict) -> str: + assert self.session is not None, f"Server '{self.config.name}' not connected" + assert self.loop is not None, "Connection loop not initialized" + fut = asyncio.run_coroutine_threadsafe( + self.session.call_tool(tool_name, arguments=arguments), self.loop + ) + result = await asyncio.wrap_future(fut) + + if result.content: + text_parts = [] + for content_item in result.content: + if hasattr(content_item, "text"): + assert isinstance(content_item, TextContent) + text_parts.append(content_item.text) + elif hasattr(content_item, "type") and content_item.type == "text": + text_parts.append(getattr(content_item, "text", str(content_item))) + else: + text_parts.append(str(content_item)) + + return "\n".join(text_parts) + + return "No result returned from tool" + + async def disconnect(self): + assert self._connection_task is not None + self._connection_task.cancel() + try: + await self._connection_task + except asyncio.CancelledError: + pass + self.logger.info(f"MCP server '{self.config.name}' terminated") diff --git a/verifiers/mcp/mcp_tool_wrapper.py b/verifiers/mcp/mcp_tool_wrapper.py new file mode 100644 index 000000000..d98ff4cea --- /dev/null +++ b/verifiers/mcp/mcp_tool_wrapper.py @@ -0,0 +1,59 @@ +from typing import Any + +from mcp.types import Tool + +from .mcp_server_connection import MCPServerConnection + + +class MCPToolWrapper: + def __init__( + self, server_name: str, tool: Tool, server_connection: MCPServerConnection + ): + self.server_name = server_name + self.tool = tool + self.server_connection = server_connection + + self.__name__ = tool.name + self.__doc__ = tool.description or "" + + self.__annotations__ = self._build_annotations() + + def _build_annotations(self) -> dict: + annotations = {} + + if self.tool.inputSchema: + properties = self.tool.inputSchema.get("properties", {}) + + for param_name, param_spec in properties.items(): + param_type = param_spec.get("type", "string") + if param_type == "string": + annotations[param_name] = str + elif param_type == "integer": + annotations[param_name] = int + elif param_type == "number": + annotations[param_name] = float + elif param_type == "boolean": + annotations[param_name] = bool + elif param_type == "array": + annotations[param_name] = list + elif param_type == "object": + annotations[param_name] = dict + else: + annotations[param_name] = Any + + annotations["return"] = str + return annotations + + async def __call__(self, **kwargs): + return await self.server_connection.call_tool(self.tool.name, kwargs) + + def to_oai_tool(self) -> dict: + return { + "type": "function", + "function": { + "name": self.__name__, + "description": self.__doc__ or "", + "parameters": self.tool.inputSchema + or {"type": "object", "properties": {}}, + }, + } diff --git a/verifiers/mcp/models.py b/verifiers/mcp/models.py new file mode 100644 index 000000000..7a20dd38e --- /dev/null +++ b/verifiers/mcp/models.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass +from typing import Dict, List + + +@dataclass +class MCPServerConfig: + name: str + command: str + args: List[str] | None = None + env: Dict[str, str] | None = None + description: str = ""