Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 96 additions & 15 deletions src/fast_agent/agents/mcp_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import asyncio
import fnmatch
import time
from abc import ABC
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -35,8 +36,8 @@

from fast_agent.agents.agent_types import AgentConfig, AgentType
from fast_agent.agents.llm_agent import DEFAULT_CAPABILITIES
from fast_agent.agents.tool_agent import ToolAgent
from fast_agent.constants import HUMAN_INPUT_TOOL_NAME
from fast_agent.agents.tool_agent import ToolAgent, ToolTimingInfo
from fast_agent.constants import FORCE_SEQUENTIAL_TOOL_CALLS, HUMAN_INPUT_TOOL_NAME
from fast_agent.core.exceptions import PromptExitError
from fast_agent.core.logging.logger import get_logger
from fast_agent.interfaces import FastAgentLLMProtocol
Expand Down Expand Up @@ -776,14 +777,12 @@ async def with_resource(

async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtended:
"""Override ToolAgent's run_tools to use MCP tools via aggregator."""
import time

if not request.tool_calls:
self.logger.warning("No tool calls found in request", data=request)
return PromptMessageExtended(role="user", tool_results={})

tool_results: dict[str, CallToolResult] = {}
tool_timings: dict[str, float] = {} # Track timing for each tool call
tool_timings: dict[str, ToolTimingInfo] = {}
tool_loop_error: str | None = None

# Cache available tool names exactly as advertised to the LLM for display/highlighting
Expand All @@ -804,8 +803,15 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend
# Cache namespaced tools for routing/metadata
namespaced_tools = self._aggregator._namespaced_tool_map

# Process each tool call using our aggregator
for correlation_id, tool_request in request.tool_calls.items():
tool_call_items = list(request.tool_calls.items())
should_parallel = (
(not FORCE_SEQUENTIAL_TOOL_CALLS) and len(tool_call_items) > 1
)

planned_calls: list[dict[str, Any]] = []

# Plan each tool call using our aggregator
for correlation_id, tool_request in tool_call_items:
tool_name = tool_request.params.name
tool_args = tool_request.params.arguments or {}
# correlation_id is the tool_use_id from the LLM
Expand Down Expand Up @@ -894,21 +900,96 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend
metadata=metadata,
)

planned_calls.append(
{
"correlation_id": correlation_id,
"tool_name": tool_name,
"tool_args": tool_args,
"display_tool_name": display_tool_name,
"namespaced_tool": namespaced_tool,
"candidate_namespaced_tool": candidate_namespaced_tool,
}
)

if should_parallel and planned_calls:
async def run_one(call: dict[str, Any]) -> tuple[str, CallToolResult, float]:
start_time = time.perf_counter()
result = await self.call_tool(
call["tool_name"], call["tool_args"], call["correlation_id"]
)
end_time = time.perf_counter()
return call["correlation_id"], result, round((end_time - start_time) * 1000, 2)

results = await asyncio.gather(
*(run_one(call) for call in planned_calls), return_exceptions=True
)

for i, item in enumerate(results):
call = planned_calls[i]
correlation_id = call["correlation_id"]
display_tool_name = call["display_tool_name"]
namespaced_tool = call["namespaced_tool"]
candidate_namespaced_tool = call["candidate_namespaced_tool"]

if isinstance(item, Exception):
self.logger.error(f"MCP tool {display_tool_name} failed: {item}")
result = CallToolResult(
content=[TextContent(type="text", text=f"Error: {str(item)}")],
isError=True,
)
duration_ms = 0.0
else:
_, result, duration_ms = item

tool_results[correlation_id] = result
tool_timings[correlation_id] = {
"timing_ms": duration_ms,
"transport_channel": getattr(result, "transport_channel", None),
}

skybridge_config = None
skybridge_tool = namespaced_tool or candidate_namespaced_tool
if skybridge_tool:
try:
skybridge_config = await self._aggregator.get_skybridge_config(
skybridge_tool.server_name
)
except Exception:
skybridge_config = None

if not getattr(result, "_suppress_display", False):
self.display.show_tool_result(
name=self._name,
result=result,
tool_name=display_tool_name,
skybridge_config=skybridge_config,
timing_ms=duration_ms,
)

return self._finalize_tool_results(
tool_results, tool_timings=tool_timings, tool_loop_error=tool_loop_error
)

for call in planned_calls:
correlation_id = call["correlation_id"]
tool_name = call["tool_name"]
tool_args = call["tool_args"]
display_tool_name = call["display_tool_name"]
namespaced_tool = call["namespaced_tool"]
candidate_namespaced_tool = call["candidate_namespaced_tool"]

try:
# Track timing for tool execution
start_time = time.perf_counter()
result = await self.call_tool(tool_name, tool_args, correlation_id)
end_time = time.perf_counter()
duration_ms = round((end_time - start_time) * 1000, 2)

tool_results[correlation_id] = result
# Store timing and transport channel info
tool_timings[correlation_id] = {
"timing_ms": duration_ms,
"transport_channel": getattr(result, "transport_channel", None)
"transport_channel": getattr(result, "transport_channel", None),
}

# Show tool result (like ToolAgent does)
skybridge_config = None
skybridge_tool = namespaced_tool or candidate_namespaced_tool
if skybridge_tool:
Expand All @@ -922,7 +1003,7 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend
result=result,
tool_name=display_tool_name,
skybridge_config=skybridge_config,
timing_ms=duration_ms, # Use local duration_ms variable for display
timing_ms=duration_ms,
)

self.logger.debug(f"MCP tool {display_tool_name} executed successfully")
Expand All @@ -933,11 +1014,11 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend
isError=True,
)
tool_results[correlation_id] = error_result

# Show error result too (no need for skybridge config on errors)
self.display.show_tool_result(name=self._name, result=error_result)

return self._finalize_tool_results(tool_results, tool_timings=tool_timings, tool_loop_error=tool_loop_error)
return self._finalize_tool_results(
tool_results, tool_timings=tool_timings, tool_loop_error=tool_loop_error
)

def _prepare_tool_display(
self,
Expand Down
147 changes: 98 additions & 49 deletions src/fast_agent/agents/tool_agent.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,34 @@
from typing import Any, Callable, Dict, List, Sequence
import asyncio
import time
from typing import Any, Callable, Dict, List, Sequence, TypedDict

from mcp.server.fastmcp.tools.base import Tool as FastMCPTool
from mcp.types import CallToolResult, ListToolsResult, Tool

from fast_agent.agents.agent_types import AgentConfig
from fast_agent.agents.llm_agent import LlmAgent
from fast_agent.agents.tool_runner import ToolRunner, ToolRunnerHooks
from fast_agent.constants import (
DEFAULT_MAX_ITERATIONS,
FAST_AGENT_ERROR_CHANNEL,
FORCE_SEQUENTIAL_TOOL_CALLS,
HUMAN_INPUT_TOOL_NAME,
)
from fast_agent.context import Context
from fast_agent.core.logging.logger import get_logger
from fast_agent.mcp.helpers.content_helpers import text_content
from fast_agent.tools.elicitation import get_elicitation_fastmcp_tool
from fast_agent.types import PromptMessageExtended, RequestParams
from fast_agent.types.llm_stop_reason import LlmStopReason

logger = get_logger(__name__)


class ToolTimingInfo(TypedDict):
    """Per-tool-call execution metadata recorded by ``run_tools``.

    Stored keyed by the tool call's correlation id and forwarded to
    ``_finalize_tool_results`` for display/reporting.
    """

    # Wall-clock duration of the tool call in milliseconds
    # (measured with time.perf_counter, rounded to 2 decimal places).
    timing_ms: float
    # Transport channel reported on the CallToolResult, when present;
    # None for local (non-MCP) tools, which have no transport.
    transport_channel: str | None


class ToolAgent(LlmAgent):
"""
A Tool Calling agent that uses FastMCP Tools for execution.
Expand Down Expand Up @@ -87,43 +96,25 @@ async def generate_impl(
if tools is None:
tools = (await self.list_tools()).tools

iterations = 0
max_iterations = request_params.max_iterations if request_params else DEFAULT_MAX_ITERATIONS
runner = ToolRunner(
agent=self,
messages=messages,
request_params=request_params,
tools=tools,
hooks=self._tool_runner_hooks(),
)
return await runner.until_done()

while True:
result = await super().generate_impl(
messages,
request_params=request_params,
tools=tools,
)
def _tool_runner_hooks(self) -> ToolRunnerHooks | None:
    """Return hooks to pass to the ToolRunner built in ``generate_impl``.

    Returns None here, i.e. no hooks; presumably subclasses override this
    to customize the tool loop — confirm against ToolRunner's hook handling.
    """
    return None

if LlmStopReason.TOOL_USE == result.stop_reason:
tool_message = await self.run_tools(result)
error_channel_messages = (tool_message.channels or {}).get(FAST_AGENT_ERROR_CHANNEL)
if error_channel_messages:
tool_result_contents = [
content
for tool_result in (tool_message.tool_results or {}).values()
for content in tool_result.content
]
if tool_result_contents:
if result.content is None:
result.content = []
result.content.extend(tool_result_contents)
result.stop_reason = LlmStopReason.ERROR
break
if self.config.use_history:
messages = [tool_message]
else:
messages.extend([result, tool_message])
else:
break

iterations += 1
if iterations > max_iterations:
logger.warning("Max iterations reached, stopping tool loop")
break
return result
async def _tool_runner_llm_step(
    self,
    messages: list[PromptMessageExtended],
    request_params: RequestParams | None = None,
    tools: list[Tool] | None = None,
) -> PromptMessageExtended:
    """Run one LLM generation step for the tool loop.

    Delegates directly to the parent class ``generate_impl`` (bypassing this
    class's override, which would re-enter the tool loop), so the ToolRunner
    controls iteration rather than the agent recursing.

    Args:
        messages: Conversation messages to send to the LLM.
        request_params: Optional per-request parameter overrides.
        tools: Tool schemas to advertise to the LLM, if any.

    Returns:
        The LLM's response message for this single step.
    """
    return await super().generate_impl(messages, request_params=request_params, tools=tools)

# we take care of tool results, so skip displaying them
def show_user_message(self, message: PromptMessageExtended) -> None:
Expand All @@ -133,19 +124,22 @@ def show_user_message(self, message: PromptMessageExtended) -> None:

async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtended:
"""Runs the tools in the request, and returns a new User message with the results"""
import time

if not request.tool_calls:
logger.warning("No tool calls found in request", data=request)
return PromptMessageExtended(role="user", tool_results={})

tool_results: dict[str, CallToolResult] = {}
tool_timings: dict[str, float] = {} # Track timing for each tool call
tool_timings: dict[str, ToolTimingInfo] = {}
tool_loop_error: str | None = None
# TODO -- use gather() for parallel results, update display
tool_schemas = (await self.list_tools()).tools
available_tools = [t.name for t in tool_schemas]
for correlation_id, tool_request in request.tool_calls.items():

tool_call_items = list(request.tool_calls.items())
should_parallel = (not FORCE_SEQUENTIAL_TOOL_CALLS) and len(tool_call_items) > 1

planned_calls: list[tuple[str, str, dict[str, Any]]] = []
for correlation_id, tool_request in tool_call_items:
tool_name = tool_request.params.name
tool_args = tool_request.params.arguments or {}

Expand All @@ -158,7 +152,61 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend
tool_results=tool_results,
)
break
planned_calls.append((correlation_id, tool_name, tool_args))

if should_parallel and planned_calls:
for correlation_id, tool_name, tool_args in planned_calls:
highlight_index = None
try:
highlight_index = available_tools.index(tool_name)
except ValueError:
pass

self.display.show_tool_call(
name=self.name,
tool_args=tool_args,
bottom_items=available_tools,
tool_name=tool_name,
highlight_index=highlight_index,
max_item_length=12,
)

async def run_one(
correlation_id: str, tool_name: str, tool_args: dict[str, Any]
) -> tuple[str, CallToolResult, float]:
start_time = time.perf_counter()
result = await self.call_tool(tool_name, tool_args)
end_time = time.perf_counter()
return correlation_id, result, round((end_time - start_time) * 1000, 2)

results = await asyncio.gather(
*(run_one(cid, name, args) for cid, name, args in planned_calls),
return_exceptions=True,
)

for i, item in enumerate(results):
correlation_id, tool_name, _ = planned_calls[i]
if isinstance(item, Exception):
msg = f"Error: {str(item)}"
result = CallToolResult(content=[text_content(msg)], isError=True)
duration_ms = 0.0
else:
_, result, duration_ms = item

tool_results[correlation_id] = result
tool_timings[correlation_id] = {
"timing_ms": duration_ms,
"transport_channel": None,
}
self.display.show_tool_result(
name=self.name, result=result, tool_name=tool_name, timing_ms=duration_ms
)

return self._finalize_tool_results(
tool_results, tool_timings=tool_timings, tool_loop_error=tool_loop_error
)

for correlation_id, tool_name, tool_args in planned_calls:
# Find the index of the current tool in available_tools for highlighting
highlight_index = None
try:
Expand All @@ -184,13 +232,14 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend

tool_results[correlation_id] = result
# Store timing info (transport_channel not available for local tools)
tool_timings[correlation_id] = {
"timing_ms": duration_ms,
"transport_channel": None
}
self.display.show_tool_result(name=self.name, result=result, tool_name=tool_name, timing_ms=duration_ms)
tool_timings[correlation_id] = {"timing_ms": duration_ms, "transport_channel": None}
self.display.show_tool_result(
name=self.name, result=result, tool_name=tool_name, timing_ms=duration_ms
)

return self._finalize_tool_results(tool_results, tool_timings=tool_timings, tool_loop_error=tool_loop_error)
return self._finalize_tool_results(
tool_results, tool_timings=tool_timings, tool_loop_error=tool_loop_error
)

def _mark_tool_loop_error(
self,
Expand All @@ -211,7 +260,7 @@ def _finalize_tool_results(
self,
tool_results: dict[str, CallToolResult],
*,
tool_timings: dict[str, dict[str, float | str | None]] | None = None,
tool_timings: dict[str, ToolTimingInfo] | None = None,
tool_loop_error: str | None = None,
) -> PromptMessageExtended:
import json
Expand Down
Loading
Loading