diff --git a/examples/run_async_agent_api_model_with_mcp.py b/examples/run_async_agent_api_model_with_mcp.py new file mode 100644 index 0000000..ff07f21 --- /dev/null +++ b/examples/run_async_agent_api_model_with_mcp.py @@ -0,0 +1,59 @@ +import asyncio +import json + +from lagent.actions import AsyncActionExecutor, AsyncMCPClient +from lagent.agents import AsyncAgentForInternLM +from lagent.agents.aggregator import InternLMToolAggregator +from lagent.agents.stream import get_plugin_prompt +from lagent.llms import AsyncGPTAPI +from lagent.prompts import PluginParser +from lagent.schema import AgentMessage + +TEMPLATE = ( + "You have access to the following tools:\n{tool_description}\nPlease provide" + " your thought process when you need to use a tool, followed by the call statement in this format:" + "\n{invocation_format}" +) +llm = dict(type=AsyncGPTAPI, model_type=None, retry=50, key=None, top_p=0.95, temperature=0.6, max_new_tokens=16384) +plugin = dict( + type=AsyncMCPClient, + name='PlayWright', + server_type='stdio', + command='npx', + args=["@playwright/mcp@latest", '--isolated', '--no-sandbox'], +) +agent = AsyncAgentForInternLM( + llm, + plugin, + template=TEMPLATE.format( + tool_description=get_plugin_prompt(plugin), + invocation_format='```json\n{"name": {{tool name}}, "parameters": {{keyword arguments}}}\n```\n', + ), + output_format=PluginParser(begin="```json\n", end="\n```\n", validate=lambda x: json.loads(x.rstrip('`'))), + aggregator=InternLMToolAggregator(environment_role='system'), +) +msg = AgentMessage( + sender='user', + content='解释一下MCP中Sampling Flow的工作机制,参考https://modelcontextprotocol.io/docs/concepts/sampling', +) + +# proj_dir = os.path.dirname(os.path.dirname(__file__)) +# executor = AsyncActionExecutor( +# dict( +# type=AsyncMCPClient, +# name='FS', +# server_type='stdio', +# command='npx', +# args=['-y', '@modelcontextprotocol/server-filesystem', os.path.join(proj_dir, 'docs')], +# ) +# ) +# msg = AgentMessage( +# sender='assistant', +# content=dict( +# name='FS.read_file', +# parameters=dict(path=os.path.join(proj_dir, 'docs/en/get_started/install.md')), +# ), +# ) +loop = asyncio.get_event_loop() +res = loop.run_until_complete(agent(msg)) +print(res.content) diff --git a/lagent/actions/__init__.py b/lagent/actions/__init__.py index b75a226..0398d82 100644 --- a/lagent/actions/__init__.py +++ b/lagent/actions/__init__.py @@ -8,6 +8,7 @@ from .ipython_interactive import AsyncIPythonInteractive, IPythonInteractive from .ipython_interpreter import AsyncIPythonInterpreter, IPythonInterpreter from .ipython_manager import IPythonInteractiveManager +from .mcp_client import AsyncMCPClient from .parser import BaseParser, JsonParser, TupleParser from .ppt import PPT, AsyncPPT from .python_interpreter import AsyncPythonInterpreter, PythonInterpreter @@ -39,6 +40,7 @@ 'AsyncPPT', 'WebBrowser', 'AsyncWebBrowser', + 'AsyncMCPClient', 'BaseParser', 'JsonParser', 'TupleParser', diff --git a/lagent/actions/base_action.py b/lagent/actions/base_action.py index b42036a..ed31b22 100644 --- a/lagent/actions/base_action.py +++ b/lagent/actions/base_action.py @@ -340,6 +340,8 @@ def sub(self, a, b): action = Calculator() """ + is_stateful = False + def __init__( self, description: Optional[dict] = None, diff --git a/lagent/actions/ipython_interpreter.py b/lagent/actions/ipython_interpreter.py index 68e9a0d..a022e4a 100644 --- a/lagent/actions/ipython_interpreter.py +++ b/lagent/actions/ipython_interpreter.py @@ -52,9 +52,7 @@ async def async_run_code( assert iopub_timeout > 
interrupt_after try: - async def get_iopub_msg_with_death_detection(kc: AsyncKernelClient, - *, - timeout=None): + async def get_iopub_msg_with_death_detection(kc: AsyncKernelClient, *, timeout=None): loop = asyncio.get_running_loop() dead_fut = loop.create_future() @@ -71,8 +69,7 @@ def dead(): km.add_restart_callback(restarting, "restart") km.add_restart_callback(dead, "dead") try: - done, _ = await asyncio.wait( - [dead_fut, msg_task], return_when=asyncio.FIRST_COMPLETED) + done, _ = await asyncio.wait([dead_fut, msg_task], return_when=asyncio.FIRST_COMPLETED) if dead_fut in done: raise KernelDeath() assert msg_task in done @@ -88,13 +85,21 @@ async def send_interrupt(): await km.interrupt_kernel() @retry( - retry=retry_if_result(lambda ret: ret[-1].strip() in [ - 'KeyboardInterrupt', - f"Kernel didn't respond in {wait_for_ready_timeout} seconds", - ] if isinstance(ret, tuple) else False), + retry=retry_if_result( + lambda ret: ( + ret[-1].strip() + in [ + 'KeyboardInterrupt', + f"Kernel didn't respond in {wait_for_ready_timeout} seconds", + ] + if isinstance(ret, tuple) + else False + ) + ), stop=stop_after_attempt(3), wait=wait_fixed(1), - retry_error_callback=lambda state: state.outcome.result()) + retry_error_callback=lambda state: state.outcome.result(), + ) async def run(): execute_result = None error_traceback = None @@ -106,11 +111,9 @@ async def run(): await kc.wait_for_ready(timeout=wait_for_ready_timeout) msg_id = kc.execute(code) while True: - message = await get_iopub_msg_with_death_detection( - kc, timeout=iopub_timeout) + message = await get_iopub_msg_with_death_detection(kc, timeout=iopub_timeout) if logger.isEnabledFor(logging.DEBUG): - logger.debug( - json.dumps(message, indent=2, default=str)) + logger.debug(json.dumps(message, indent=2, default=str)) assert message["parent_header"]["msg_id"] == msg_id msg_type = message["msg_type"] if msg_type == "status": @@ -136,8 +139,7 @@ async def run(): if interrupt_after: run_task = asyncio.create_task(run()) send_interrupt_task = asyncio.create_task(send_interrupt()) - done, _ = await asyncio.wait([run_task, send_interrupt_task], - return_when=asyncio.FIRST_COMPLETED) + done, _ = await asyncio.wait([run_task, send_interrupt_task], return_when=asyncio.FIRST_COMPLETED) if run_task in done: send_interrupt_task.cancel() else: @@ -216,13 +218,10 @@ def reset(self): if not self._initialized: self.initialize() else: - code = "get_ipython().run_line_magic('reset', '-f')\n" + \ - START_CODE.format(self.user_data_dir) + code = "get_ipython().run_line_magic('reset', '-f')\n" + START_CODE.format(self.user_data_dir) self._call(code, None) - def _call(self, - command: str, - timeout: Optional[int] = None) -> Tuple[str, bool]: + def _call(self, command: str, timeout: Optional[int] = None) -> Tuple[str, bool]: self.initialize() command = extract_code(command) @@ -261,16 +260,14 @@ def _inner_call(): text = msg['content']['data'].get('text/plain', '') if 'image/png' in msg['content']['data']: image_b64 = msg['content']['data']['image/png'] - image_url = publish_image_to_local( - image_b64, self.work_dir) + image_url = publish_image_to_local(image_b64, self.work_dir) image_idx += 1 image = '![fig-%03d](%s)' % (image_idx, image_url) elif msg_type == 'display_data': if 'image/png' in msg['content']['data']: image_b64 = msg['content']['data']['image/png'] - image_url = publish_image_to_local( - image_b64, self.work_dir) + image_url = publish_image_to_local(image_b64, self.work_dir) image_idx += 1 image = '![fig-%03d](%s)' % (image_idx, image_url) 
@@ -281,8 +278,7 @@ def _inner_call(): text = msg['content']['text'] elif msg_type == 'error': succeed = False - text = escape_ansi('\n'.join( - msg['content']['traceback'])) + text = escape_ansi('\n'.join(msg['content']['traceback'])) if 'M6_CODE_INTERPRETER_TIMEOUT' in text: text = f'Timeout. No response after {timeout} seconds.' # noqa except queue.Empty: @@ -349,8 +345,7 @@ def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn: # text=result['text'], image=result.get('image', [])[0]) tool_return.state = ActionStatusCode.SUCCESS else: - tool_return.errmsg = result.get('text', '') if isinstance( - result, dict) else result + tool_return.errmsg = result.get('text', '') if isinstance(result, dict) else result tool_return.state = ActionStatusCode.API_ERROR return tool_return @@ -371,6 +366,7 @@ class AsyncIPythonInterpreter(AsyncActionMixin, IPythonInterpreter): action's inputs and outputs. Defaults to :class:`JsonParser`. """ + is_stateful = True _UNBOUND_KERNEL_CLIENTS = asyncio.Queue() def __init__( @@ -390,8 +386,7 @@ def __init__( c = Config() c.KernelManager.transport = 'ipc' - self._amkm = AsyncMultiKernelManager( - config=c, connection_dir=connection_dir) + self._amkm = AsyncMultiKernelManager(config=c, connection_dir=connection_dir) self._max_kernels = max_kernels self._reuse_kernel = reuse_kernel self._sem = asyncio.Semaphore(startup_rate) @@ -403,25 +398,23 @@ async def initialize(self, session_id: str): if session_id in self._KERNEL_CLIENTS: return self._KERNEL_CLIENTS[session_id] if self._reuse_kernel and not self._UNBOUND_KERNEL_CLIENTS.empty(): - self._KERNEL_CLIENTS[ - session_id] = await self._UNBOUND_KERNEL_CLIENTS.get() + self._KERNEL_CLIENTS[session_id] = await self._UNBOUND_KERNEL_CLIENTS.get() return self._KERNEL_CLIENTS[session_id] async with self._sem: - if self._max_kernels is None or len( - self._KERNEL_CLIENTS - ) + self._UNBOUND_KERNEL_CLIENTS.qsize() < self._max_kernels: + if ( + self._max_kernels is None + or len(self._KERNEL_CLIENTS) + self._UNBOUND_KERNEL_CLIENTS.qsize() < self._max_kernels + ): kernel_id = None try: kernel_id = await self._amkm.start_kernel() kernel = self._amkm.get_kernel(kernel_id) client = kernel.client() _, error_stacktrace, stream_text = await async_run_code( - kernel, - START_CODE.format(self.user_data_dir), - shutdown_kernel=False) + kernel, START_CODE.format(self.user_data_dir), shutdown_kernel=False + ) # check if the output of START_CODE meets expectations - if not (error_stacktrace is None - and stream_text == ''): + if not (error_stacktrace is None and stream_text == ''): raise RuntimeError except Exception as e: print(f'Starting kernel error: {e}') @@ -431,15 +424,11 @@ async def initialize(self, session_id: str): await asyncio.sleep(1) continue if self._max_kernels is None: - self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel, - client) + self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel, client) return kernel_id, kernel, client async with self._lock: - if len(self._KERNEL_CLIENTS - ) + self._UNBOUND_KERNEL_CLIENTS.qsize( - ) < self._max_kernels: - self._KERNEL_CLIENTS[session_id] = (kernel_id, - kernel, client) + if len(self._KERNEL_CLIENTS) + self._UNBOUND_KERNEL_CLIENTS.qsize() < self._max_kernels: + self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel, client) return kernel_id, kernel, client await self._amkm.shutdown_kernel(kernel_id) self._amkm.remove_kernel(kernel_id) @@ -450,8 +439,7 @@ async def reset(self, session_id: str): if session_id not in self._KERNEL_CLIENTS: return _, kernel, _ = 
self._KERNEL_CLIENTS[session_id] - code = "get_ipython().run_line_magic('reset', '-f')\n" + \ - START_CODE.format(self.user_data_dir) + code = "get_ipython().run_line_magic('reset', '-f')\n" + START_CODE.format(self.user_data_dir) await async_run_code(kernel, code, shutdown_kernel=False) async def shutdown(self, session_id: str): @@ -467,18 +455,15 @@ async def close_session(self, session_id: str): if self._reuse_kernel: if session_id in self._KERNEL_CLIENTS: await self.reset(session_id) - await self._UNBOUND_KERNEL_CLIENTS.put( - self._KERNEL_CLIENTS.pop(session_id)) + await self._UNBOUND_KERNEL_CLIENTS.put(self._KERNEL_CLIENTS.pop(session_id)) else: await self.shutdown(session_id) async def _call(self, command, timeout=None, session_id=None): _, kernel, _ = await self.initialize(str(session_id)) result = await async_run_code( - kernel, - extract_code(command), - interrupt_after=timeout or self.timeout, - shutdown_kernel=False) + kernel, extract_code(command), interrupt_after=timeout or self.timeout, shutdown_kernel=False + ) execute_result, error_stacktrace, stream_text = result if error_stacktrace is not None: ret = re.sub('^-*\n', '', escape_ansi(error_stacktrace)) @@ -492,10 +477,7 @@ async def _call(self, command, timeout=None, session_id=None): return status, ret @tool_api - async def run(self, - command: str, - timeout: Optional[int] = None, - session_id: Optional[str] = None) -> ActionReturn: + async def run(self, command: str, timeout: Optional[int] = None, session_id: Optional[str] = None) -> ActionReturn: r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail. Args: @@ -516,8 +498,7 @@ async def run(self, # text=result['text'], image=result.get('image', [])[0]) tool_return.state = ActionStatusCode.SUCCESS else: - tool_return.errmsg = result.get('text', '') if isinstance( - result, dict) else result + tool_return.errmsg = result.get('text', '') if isinstance(result, dict) else result tool_return.state = ActionStatusCode.API_ERROR return tool_return @@ -549,6 +530,7 @@ def escape_ansi(line): def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'): import PIL.Image + image_file = str(uuid.uuid4()) + '.png' local_image_file = os.path.join(work_dir, image_file) diff --git a/lagent/actions/mcp_client.py b/lagent/actions/mcp_client.py new file mode 100644 index 0000000..9d5683a --- /dev/null +++ b/lagent/actions/mcp_client.py @@ -0,0 +1,191 @@ +import asyncio +import logging +from contextlib import AsyncExitStack +from typing import Literal, TypeAlias + +from lagent.actions.base_action import BaseAction +from lagent.actions.parser import JsonParser, ParseError +from lagent.schema import ActionReturn, ActionStatusCode + +ServerType: TypeAlias = Literal["stdio", "sse", "http"] + +logger = logging.getLogger(__name__) +_loop = None + + +def _get_event_loop(): + try: + event_loop = asyncio.get_event_loop() + except Exception: + logger.warning('Can not found event loop in current thread. 
Create a new event loop.') + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + + if event_loop.is_running(): + global _loop + if _loop: + return _loop + + from threading import Thread + + def _start_loop(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + event_loop = asyncio.new_event_loop() + Thread(target=_start_loop, args=(event_loop,), daemon=True).start() + _loop = event_loop + return event_loop + + +class AsyncMCPClient(BaseAction): + """Model Context Protocol (MCP) Client for asynchronous communication with MCP servers. + + Args: + name (str): The name of the action. Make sure it is unique among all actions. + server_type (ServerType): The type of MCP server to connect to. Options are "stdio", "sse", or "http". + **server_params: Additional parameters for the server connection, which may include: + - For stdio servers: + - command (str): The command to run the MCP server. + - args (list, optional): Additional arguments for the command. + - env (dict, optional): Environment variables for the command. + - cwd (str, optional): Current working directory for the command. + - For sse servers: + - url (str): The URL of the MCP server. + - headers (dict, optional): Headers to include in the request. + - timeout (int, optional): Timeout for the request. + - sse_read_timeout (int, optional): Timeout for reading SSE events. + - For http servers: + - url (str): The URL of the MCP server. + - headers (dict, optional): Headers to include in the request. + - timeout (int, optional): Timeout for the request. + - sse_read_timeout (int, optional): Timeout for reading SSE events. + - terminate_on_close (bool, optional): Whether to terminate the connection on close. + """ + + is_stateful = True + + def __init__(self, name: str, server_type: ServerType, **server_params): + self._is_toolkit = True + self._sessions: dict = {} + self.server_type = server_type + self.server_params = server_params + self.exit_stack = AsyncExitStack() + # get the list of tools from the MCP server + loop = _get_event_loop() + if loop.is_running(): + fut = asyncio.run_coroutine_threadsafe(self.list_tools(), loop) + tools = fut.result() + else: + tools = loop.run_until_complete(self.list_tools()) + self._api_names = {tool.name for tool in tools} + super().__init__( + description=dict( + name=name, + api_list=[ + { + 'name': tool.name, + 'description': tool.description, + 'parameters': [ + {'name': k, 'type': v['type'].upper(), 'description': v.get('description', '')} + for k, v in tool.inputSchema['properties'].items() + ], + 'required': tool.inputSchema.get('required', []), + } + for tool in tools + ], + ), + parser=JsonParser, + ) + + async def initialize(self, session_id): + """Initialize the MCP client and connect to the server.""" + if session_id in self._sessions: + return self._sessions[session_id] + + from mcp import ClientSession, StdioServerParameters + + if self.server_type == "stdio": + from mcp.client.stdio import stdio_client + + logger.info( + f"Connecting to stdio MCP server with command: {self.server_params['command']} " + f"{self.server_params.get('args', [])}" + ) + + client_kwargs = {"command": self.server_params["command"]} + for key in ["args", "env", "cwd"]: + if self.server_params.get(key) is not None: + client_kwargs[key] = self.server_params[key] + server_params = StdioServerParameters(**client_kwargs) + read, write = await self.exit_stack.enter_async_context(stdio_client(server_params)) + elif self.server_type == "sse": + from mcp.client.sse import sse_client + + 
logger.info(f"Connecting to SSE MCP server at: {self.server_params['url']}") + + client_kwargs = {"url": self.server_params["url"]} + for key in ["headers", "timeout", "sse_read_timeout"]: + if self.server_params.get(key) is not None: + client_kwargs[key] = self.server_params[key] + read, write = await self.exit_stack.enter_async_context(sse_client(**client_kwargs)) + elif self.server_type == "http": + from mcp.client.streamable_http import streamablehttp_client + + logger.info(f"Connecting to StreamableHTTP MCP server at: {self.server_params['url']}") + + client_kwargs = {"url": self.server_params["url"]} + for key in ["headers", "timeout", "sse_read_timeout", "terminate_on_close"]: + if self.server_params.get(key) is not None: + client_kwargs[key] = self.server_params[key] + read, write, _ = await self.exit_stack.enter_async_context(streamablehttp_client(**client_kwargs)) + else: + raise ValueError(f"Unsupported server type: {self.server_type}") + + session = await self.exit_stack.enter_async_context(ClientSession(read, write)) + await session.initialize() + self._sessions[session_id] = session + return session + + async def cleanup(self): + await self.exit_stack.aclose() + + async def list_tools(self, session_id=0) -> list: + session = await self.initialize(session_id=session_id) + return (await session.list_tools()).tools + + def __del__(self): + loop = _get_event_loop() + if loop.is_running(): + fut = asyncio.run_coroutine_threadsafe(self.cleanup(), loop) + fut.result() + else: + loop.run_until_complete(self.cleanup()) + + async def __call__(self, inputs: str, name: str) -> ActionReturn: + session_id = inputs.pop('session_id', 0) if isinstance(inputs, dict) else 0 + fallback_args = {'inputs': inputs, 'name': name} + if name not in self._api_names: + return ActionReturn( + fallback_args, type=self.name, errmsg=f'invalid API: {name}', state=ActionStatusCode.API_ERROR + ) + try: + inputs = self._parser.parse_inputs(inputs, name) + except ParseError as exc: + return ActionReturn(fallback_args, type=self.name, errmsg=exc.err_msg, state=ActionStatusCode.ARGS_ERROR) + try: + session = await self.initialize(session_id) + outputs = await session.call_tool(name, inputs) + outputs = outputs.content[0].text + except Exception as exc: + return ActionReturn(inputs, type=self.name, errmsg=str(exc), state=ActionStatusCode.API_ERROR) + if isinstance(outputs, ActionReturn): + action_return = outputs + if not action_return.args: + action_return.args = inputs + if not action_return.type: + action_return.type = self.name + else: + result = self._parser.parse_outputs(outputs) + action_return = ActionReturn(inputs, type=self.name, result=result) + return action_return diff --git a/lagent/hooks/action_preprocessor.py b/lagent/hooks/action_preprocessor.py index 51083aa..09f6ec6 100644 --- a/lagent/hooks/action_preprocessor.py +++ b/lagent/hooks/action_preprocessor.py @@ -1,3 +1,4 @@ +import inspect from copy import deepcopy from lagent.schema import ActionReturn, ActionStatusCode, FunctionCall @@ -11,17 +12,20 @@ class ActionPreprocessor(Hook): """ def before_action(self, executor, message, session_id): - assert isinstance(message.formatted, FunctionCall) or ( - isinstance(message.formatted, dict) and 'name' in message.content - and 'parameters' in message.formatted) or ( + assert ( + isinstance(message.formatted, FunctionCall) + or ( + isinstance(message.formatted, dict) and 'name' in message.content and 'parameters' in message.formatted + ) + or ( 'action' in message.formatted and 'parameters' in 
message.formatted['action'] - and 'name' in message.formatted['action']) + and 'name' in message.formatted['action'] + ) + ) if isinstance(message.formatted, dict): - name = message.formatted.get('name', - message.formatted['action']['name']) - parameters = message.formatted.get( - 'parameters', message.formatted['action']['parameters']) + name = message.formatted.get('name', message.formatted['action']['name']) + parameters = message.formatted.get('parameters', message.formatted['action']['parameters']) else: name = message.formatted.name parameters = message.formatted.parameters @@ -48,15 +52,28 @@ def __init__(self, code_parameter: str = 'command'): def before_action(self, executor, message, session_id): message = deepcopy(message) - assert isinstance(message.formatted, dict) and set( - message.formatted).issuperset( - {'tool_type', 'thought', 'action', 'status'}) - if isinstance(message.formatted['action'], str): - # encapsulate code interpreter arguments - action_name = next(iter(executor.actions)) - parameters = {self.code_parameter: message.formatted['action']} - if action_name in ['AsyncIPythonInterpreter']: - parameters['session_id'] = session_id - message.formatted['action'] = dict( - name=action_name, parameters=parameters) + assert isinstance(message.formatted, dict) and set(message.formatted).issuperset( + {'tool_type', 'thought', 'action', 'status'} + ) + if message.formatted['tool_type'] == 'interpreter' and isinstance(message.formatted['action'], str): + for action in executor.actions.values(): + if hasattr(action, 'run') and callable(action.run): + param = inspect.signature(action.run).parameters + if self.code_parameter in param: + # encapsulate code interpreter arguments + message.formatted['action'] = dict( + name=action.name, parameters={self.code_parameter: message.formatted['action']} + ) + break + else: + raise ValueError( + f"Action '{message.formatted['action']}' is not supported by any action in the executor." + ) + tool_call = message.formatted['action'] + if ( + isinstance(tool_call, dict) + and isinstance(tool_call.get('parameters', {}), dict) + and executor.actions[tool_call['name'].split('.')[0]].is_stateful + ): + tool_call['parameters']['session_id'] = session_id return super().before_action(executor, message, session_id) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index ac0b85c..14930de 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -12,7 +12,8 @@ jsonschema jupyter==1.0.0 jupyter_client==8.6.2 jupyter_core==5.7.2 -pydantic==2.6.4 +mcp +pydantic requests tenacity termcolor
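
Usage note (not part of the diff): a minimal sketch of driving the new AsyncMCPClient directly, without wrapping it in an agent. It mirrors the commented-out filesystem block in examples/run_async_agent_api_model_with_mcp.py and the event-loop pattern at the bottom of that script. The `read_file` tool name is inferred from the `FS.read_file` call in that commented block, and passing a plain dict as `inputs` relies on the dict handling visible in `AsyncMCPClient.__call__`; treat both as assumptions to verify against your MCP server version.

# Illustrative sketch only; assumes @modelcontextprotocol/server-filesystem
# exposes a `read_file` tool, as implied by the commented-out FS.read_file example.
import asyncio
import os

from lagent.actions import AsyncMCPClient

proj_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Tool discovery happens inside __init__ (list_tools), so construction alone
# already verifies that the stdio server can be spawned and queried.
client = AsyncMCPClient(
    name='FS',
    server_type='stdio',
    command='npx',
    args=['-y', '@modelcontextprotocol/server-filesystem', os.path.join(proj_dir, 'docs')],
)

# __call__(inputs, name) validates `name` against the tools advertised by the
# server, parses the arguments with JsonParser, then forwards them via
# session.call_tool and wraps the returned text in an ActionReturn.
loop = asyncio.get_event_loop()
ret = loop.run_until_complete(
    client(
        dict(path=os.path.join(proj_dir, 'docs/en/get_started/install.md')),
        name='read_file',
    )
)
print(ret.result or ret.errmsg)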