From f82e113347276bfc00a4205c1435d96ff7a378fd Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Wed, 17 Dec 2025 11:32:46 +0000 Subject: [PATCH 1/5] refactor to tool runner loop --- src/fast_agent/agents/mcp_agent.py | 107 ++++++++- src/fast_agent/agents/tool_agent.py | 134 +++++++---- src/fast_agent/agents/tool_runner.py | 227 ++++++++++++++++++ .../agents/workflow/agents_as_tools_agent.py | 27 ++- src/fast_agent/constants.py | 3 + .../agents/test_tool_runner_hooks.py | 117 +++++++++ 6 files changed, 549 insertions(+), 66 deletions(-) create mode 100644 src/fast_agent/agents/tool_runner.py create mode 100644 tests/unit/fast_agent/agents/test_tool_runner_hooks.py diff --git a/src/fast_agent/agents/mcp_agent.py b/src/fast_agent/agents/mcp_agent.py index 9f6880b6..93303f58 100644 --- a/src/fast_agent/agents/mcp_agent.py +++ b/src/fast_agent/agents/mcp_agent.py @@ -36,7 +36,7 @@ from fast_agent.agents.agent_types import AgentConfig, AgentType from fast_agent.agents.llm_agent import DEFAULT_CAPABILITIES from fast_agent.agents.tool_agent import ToolAgent -from fast_agent.constants import HUMAN_INPUT_TOOL_NAME +from fast_agent.constants import FORCE_SEQUENTIAL_TOOL_CALLS, HUMAN_INPUT_TOOL_NAME from fast_agent.core.exceptions import PromptExitError from fast_agent.core.logging.logger import get_logger from fast_agent.interfaces import FastAgentLLMProtocol @@ -776,6 +776,7 @@ async def with_resource( async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtended: """Override ToolAgent's run_tools to use MCP tools via aggregator.""" + import asyncio import time if not request.tool_calls: @@ -783,7 +784,7 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend return PromptMessageExtended(role="user", tool_results={}) tool_results: dict[str, CallToolResult] = {} - tool_timings: dict[str, float] = {} # Track timing for each tool call + tool_timings: dict[str, dict[str, float | str | None]] = {} tool_loop_error: str | None = None # Cache available tool names exactly as advertised to the LLM for display/highlighting @@ -804,8 +805,15 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend # Cache namespaced tools for routing/metadata namespaced_tools = self._aggregator._namespaced_tool_map - # Process each tool call using our aggregator - for correlation_id, tool_request in request.tool_calls.items(): + tool_call_items = list(request.tool_calls.items()) + should_parallel = ( + (not FORCE_SEQUENTIAL_TOOL_CALLS) and len(tool_call_items) > 1 + ) + + planned_calls: list[dict[str, Any]] = [] + + # Plan each tool call using our aggregator + for correlation_id, tool_request in tool_call_items: tool_name = tool_request.params.name tool_args = tool_request.params.arguments or {} # correlation_id is the tool_use_id from the LLM @@ -894,21 +902,96 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend metadata=metadata, ) + planned_calls.append( + { + "correlation_id": correlation_id, + "tool_name": tool_name, + "tool_args": tool_args, + "display_tool_name": display_tool_name, + "namespaced_tool": namespaced_tool, + "candidate_namespaced_tool": candidate_namespaced_tool, + } + ) + + if should_parallel and planned_calls: + async def run_one(call: dict[str, Any]) -> tuple[str, CallToolResult, float]: + start_time = time.perf_counter() + result = await self.call_tool( + call["tool_name"], call["tool_args"], call["correlation_id"] + ) + end_time = time.perf_counter() + 
return call["correlation_id"], result, round((end_time - start_time) * 1000, 2) + + results = await asyncio.gather( + *(run_one(call) for call in planned_calls), return_exceptions=True + ) + + for i, item in enumerate(results): + call = planned_calls[i] + correlation_id = call["correlation_id"] + display_tool_name = call["display_tool_name"] + namespaced_tool = call["namespaced_tool"] + candidate_namespaced_tool = call["candidate_namespaced_tool"] + + if isinstance(item, Exception): + self.logger.error(f"MCP tool {display_tool_name} failed: {item}") + result = CallToolResult( + content=[TextContent(type="text", text=f"Error: {str(item)}")], + isError=True, + ) + duration_ms = 0.0 + else: + _, result, duration_ms = item + + tool_results[correlation_id] = result + tool_timings[correlation_id] = { + "timing_ms": duration_ms, + "transport_channel": getattr(result, "transport_channel", None), + } + + skybridge_config = None + skybridge_tool = namespaced_tool or candidate_namespaced_tool + if skybridge_tool: + try: + skybridge_config = await self._aggregator.get_skybridge_config( + skybridge_tool.server_name + ) + except Exception: + skybridge_config = None + + if not getattr(result, "_suppress_display", False): + self.display.show_tool_result( + name=self._name, + result=result, + tool_name=display_tool_name, + skybridge_config=skybridge_config, + timing_ms=duration_ms, + ) + + return self._finalize_tool_results( + tool_results, tool_timings=tool_timings, tool_loop_error=tool_loop_error + ) + + for call in planned_calls: + correlation_id = call["correlation_id"] + tool_name = call["tool_name"] + tool_args = call["tool_args"] + display_tool_name = call["display_tool_name"] + namespaced_tool = call["namespaced_tool"] + candidate_namespaced_tool = call["candidate_namespaced_tool"] + try: - # Track timing for tool execution start_time = time.perf_counter() result = await self.call_tool(tool_name, tool_args, correlation_id) end_time = time.perf_counter() duration_ms = round((end_time - start_time) * 1000, 2) tool_results[correlation_id] = result - # Store timing and transport channel info tool_timings[correlation_id] = { "timing_ms": duration_ms, - "transport_channel": getattr(result, "transport_channel", None) + "transport_channel": getattr(result, "transport_channel", None), } - # Show tool result (like ToolAgent does) skybridge_config = None skybridge_tool = namespaced_tool or candidate_namespaced_tool if skybridge_tool: @@ -922,7 +1005,7 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend result=result, tool_name=display_tool_name, skybridge_config=skybridge_config, - timing_ms=duration_ms, # Use local duration_ms variable for display + timing_ms=duration_ms, ) self.logger.debug(f"MCP tool {display_tool_name} executed successfully") @@ -933,11 +1016,11 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend isError=True, ) tool_results[correlation_id] = error_result - - # Show error result too (no need for skybridge config on errors) self.display.show_tool_result(name=self._name, result=error_result) - return self._finalize_tool_results(tool_results, tool_timings=tool_timings, tool_loop_error=tool_loop_error) + return self._finalize_tool_results( + tool_results, tool_timings=tool_timings, tool_loop_error=tool_loop_error + ) def _prepare_tool_display( self, diff --git a/src/fast_agent/agents/tool_agent.py b/src/fast_agent/agents/tool_agent.py index 1878b155..e11910a8 100644 --- a/src/fast_agent/agents/tool_agent.py +++ 
b/src/fast_agent/agents/tool_agent.py @@ -5,9 +5,10 @@ from fast_agent.agents.agent_types import AgentConfig from fast_agent.agents.llm_agent import LlmAgent +from fast_agent.agents.tool_runner import ToolRunner, ToolRunnerHooks from fast_agent.constants import ( - DEFAULT_MAX_ITERATIONS, FAST_AGENT_ERROR_CHANNEL, + FORCE_SEQUENTIAL_TOOL_CALLS, HUMAN_INPUT_TOOL_NAME, ) from fast_agent.context import Context @@ -15,7 +16,6 @@ from fast_agent.mcp.helpers.content_helpers import text_content from fast_agent.tools.elicitation import get_elicitation_fastmcp_tool from fast_agent.types import PromptMessageExtended, RequestParams -from fast_agent.types.llm_stop_reason import LlmStopReason logger = get_logger(__name__) @@ -87,43 +87,25 @@ async def generate_impl( if tools is None: tools = (await self.list_tools()).tools - iterations = 0 - max_iterations = request_params.max_iterations if request_params else DEFAULT_MAX_ITERATIONS + runner = ToolRunner( + agent=self, + messages=messages, + request_params=request_params, + tools=tools, + hooks=self._tool_runner_hooks(), + ) + return await runner.until_done() - while True: - result = await super().generate_impl( - messages, - request_params=request_params, - tools=tools, - ) + def _tool_runner_hooks(self) -> ToolRunnerHooks | None: + return None - if LlmStopReason.TOOL_USE == result.stop_reason: - tool_message = await self.run_tools(result) - error_channel_messages = (tool_message.channels or {}).get(FAST_AGENT_ERROR_CHANNEL) - if error_channel_messages: - tool_result_contents = [ - content - for tool_result in (tool_message.tool_results or {}).values() - for content in tool_result.content - ] - if tool_result_contents: - if result.content is None: - result.content = [] - result.content.extend(tool_result_contents) - result.stop_reason = LlmStopReason.ERROR - break - if self.config.use_history: - messages = [tool_message] - else: - messages.extend([result, tool_message]) - else: - break - - iterations += 1 - if iterations > max_iterations: - logger.warning("Max iterations reached, stopping tool loop") - break - return result + async def _tool_runner_llm_step( + self, + messages: list[PromptMessageExtended], + request_params: RequestParams | None = None, + tools: list[Tool] | None = None, + ) -> PromptMessageExtended: + return await super().generate_impl(messages, request_params=request_params, tools=tools) # we take care of tool results, so skip displaying them def show_user_message(self, message: PromptMessageExtended) -> None: @@ -133,6 +115,7 @@ def show_user_message(self, message: PromptMessageExtended) -> None: async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtended: """Runs the tools in the request, and returns a new User message with the results""" + import asyncio import time if not request.tool_calls: @@ -140,12 +123,17 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend return PromptMessageExtended(role="user", tool_results={}) tool_results: dict[str, CallToolResult] = {} - tool_timings: dict[str, float] = {} # Track timing for each tool call + tool_timings: dict[str, dict[str, float | str | None]] = {} tool_loop_error: str | None = None # TODO -- use gather() for parallel results, update display tool_schemas = (await self.list_tools()).tools available_tools = [t.name for t in tool_schemas] - for correlation_id, tool_request in request.tool_calls.items(): + + tool_call_items = list(request.tool_calls.items()) + should_parallel = (not FORCE_SEQUENTIAL_TOOL_CALLS) and 
len(tool_call_items) > 1 + + planned_calls: list[tuple[str, str, dict[str, Any]]] = [] + for correlation_id, tool_request in tool_call_items: tool_name = tool_request.params.name tool_args = tool_request.params.arguments or {} @@ -158,6 +146,61 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend tool_results=tool_results, ) break + planned_calls.append((correlation_id, tool_name, tool_args)) + + if should_parallel and planned_calls: + for correlation_id, tool_name, tool_args in planned_calls: + highlight_index = None + try: + highlight_index = available_tools.index(tool_name) + except ValueError: + pass + + self.display.show_tool_call( + name=self.name, + tool_args=tool_args, + bottom_items=available_tools, + tool_name=tool_name, + highlight_index=highlight_index, + max_item_length=12, + ) + + async def run_one( + correlation_id: str, tool_name: str, tool_args: dict[str, Any] + ) -> tuple[str, CallToolResult, float]: + start_time = time.perf_counter() + result = await self.call_tool(tool_name, tool_args) + end_time = time.perf_counter() + return correlation_id, result, round((end_time - start_time) * 1000, 2) + + results = await asyncio.gather( + *(run_one(cid, name, args) for cid, name, args in planned_calls), + return_exceptions=True, + ) + + for i, item in enumerate(results): + correlation_id, tool_name, _ = planned_calls[i] + if isinstance(item, Exception): + msg = f"Error: {str(item)}" + result = CallToolResult(content=[text_content(msg)], isError=True) + duration_ms = 0.0 + else: + _, result, duration_ms = item + + tool_results[correlation_id] = result + tool_timings[correlation_id] = { + "timing_ms": duration_ms, + "transport_channel": None, + } + self.display.show_tool_result( + name=self.name, result=result, tool_name=tool_name, timing_ms=duration_ms + ) + + return self._finalize_tool_results( + tool_results, tool_timings=tool_timings, tool_loop_error=tool_loop_error + ) + + for correlation_id, tool_name, tool_args in planned_calls: # Find the index of the current tool in available_tools for highlighting highlight_index = None @@ -184,13 +227,14 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend tool_results[correlation_id] = result # Store timing info (transport_channel not available for local tools) - tool_timings[correlation_id] = { - "timing_ms": duration_ms, - "transport_channel": None - } - self.display.show_tool_result(name=self.name, result=result, tool_name=tool_name, timing_ms=duration_ms) + tool_timings[correlation_id] = {"timing_ms": duration_ms, "transport_channel": None} + self.display.show_tool_result( + name=self.name, result=result, tool_name=tool_name, timing_ms=duration_ms + ) - return self._finalize_tool_results(tool_results, tool_timings=tool_timings, tool_loop_error=tool_loop_error) + return self._finalize_tool_results( + tool_results, tool_timings=tool_timings, tool_loop_error=tool_loop_error + ) def _mark_tool_loop_error( self, diff --git a/src/fast_agent/agents/tool_runner.py b/src/fast_agent/agents/tool_runner.py new file mode 100644 index 00000000..2e5c49c8 --- /dev/null +++ b/src/fast_agent/agents/tool_runner.py @@ -0,0 +1,227 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Protocol, + Union, +) + +from mcp.types import TextContent + +from fast_agent.constants import DEFAULT_MAX_ITERATIONS, FAST_AGENT_ERROR_CHANNEL +from fast_agent.types import PromptMessageExtended, RequestParams +from 
fast_agent.types.llm_stop_reason import LlmStopReason + +if TYPE_CHECKING: + from mcp import Tool + + from fast_agent.agents.llm_decorator import LlmDecorator # noqa: F401 + + +class _ToolLoopAgent(Protocol): + config: Any + + async def _tool_runner_llm_step( + self, + messages: list[PromptMessageExtended], + request_params: RequestParams | None = None, + tools: list[Tool] | None = None, + ) -> PromptMessageExtended: ... + + async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtended: ... + + +@dataclass(frozen=True) +class ToolRunnerHooks: + """ + Optional hook points for customizing the tool loop. + + These hooks are intentionally low-level and mutation-friendly: they can + inspect and modify the agent history (via agent.load_message_history), + tweak request params, or append extra messages via the runner. + """ + + before_llm_call: ( + Callable[["ToolRunner", list[PromptMessageExtended]], Awaitable[None]] | None + ) = None + after_llm_call: Callable[["ToolRunner", PromptMessageExtended], Awaitable[None]] | None = None + before_tool_call: Callable[["ToolRunner", PromptMessageExtended], Awaitable[None]] | None = None + after_tool_call: Callable[["ToolRunner", PromptMessageExtended], Awaitable[None]] | None = None + + +class ToolRunner: + """ + Async-iterable tool runner. + + Yields assistant messages (LLM responses). If the response requests tools, + a tool response is prepared and sent on the next iteration. + """ + + def __init__( + self, + *, + agent: _ToolLoopAgent, + messages: list[PromptMessageExtended], + request_params: RequestParams | None, + tools: list[Tool] | None, + hooks: ToolRunnerHooks | None = None, + ) -> None: + self._agent = agent + self._delta_messages: list[PromptMessageExtended] = list(messages) + self._request_params = request_params + self._tools = tools + self._hooks = hooks or ToolRunnerHooks() + + self._iteration = 0 + self._done = False + self._last_message: PromptMessageExtended | None = None + + self._pending_tool_request: PromptMessageExtended | None = None + self._pending_tool_response: PromptMessageExtended | None = None + + def __aiter__(self) -> "ToolRunner": + return self + + async def __anext__(self) -> PromptMessageExtended: + if self._done: + raise StopAsyncIteration + + await self._ensure_tool_response_staged() + if self._done: + raise StopAsyncIteration + + full_history = self._full_history_for_next_call() + if self._hooks.before_llm_call is not None: + await self._hooks.before_llm_call(self, full_history) + full_history = self._full_history_for_next_call() + + assistant_message = await self._agent._tool_runner_llm_step( + self._delta_messages, + request_params=self._request_params, + tools=self._tools, + ) + + self._last_message = assistant_message + if self._hooks.after_llm_call is not None: + await self._hooks.after_llm_call(self, assistant_message) + + if assistant_message.stop_reason == LlmStopReason.TOOL_USE: + self._pending_tool_request = assistant_message + else: + self._done = True + + return assistant_message + + async def until_done(self) -> PromptMessageExtended: + last: PromptMessageExtended | None = None + async for message in self: + last = message + if last is None: + raise RuntimeError("ToolRunner produced no messages") + return last + + async def generate_tool_call_response(self) -> PromptMessageExtended | None: + if self._pending_tool_request is None: + return None + if self._pending_tool_response is not None: + return self._pending_tool_response + + if self._hooks.before_tool_call is not None: + await 
self._hooks.before_tool_call(self, self._pending_tool_request) + + tool_message = await self._agent.run_tools(self._pending_tool_request) + self._pending_tool_response = tool_message + + if self._hooks.after_tool_call is not None: + await self._hooks.after_tool_call(self, tool_message) + + self._stage_tool_response(tool_message) + self._pending_tool_request = None + + return tool_message + + def set_request_params(self, params: RequestParams) -> None: + self._request_params = params + + def append_messages(self, *messages: Union[str, PromptMessageExtended]) -> None: + for message in messages: + if isinstance(message, str): + self._delta_messages.append( + PromptMessageExtended( + role="user", + content=[TextContent(type="text", text=message)], + ) + ) + else: + self._delta_messages.append(message) + + @property + def messages(self) -> list[PromptMessageExtended]: + return self._full_history_for_next_call() + + @property + def iteration(self) -> int: + return self._iteration + + @property + def is_done(self) -> bool: + return self._done + + @property + def last_message(self) -> PromptMessageExtended | None: + return self._last_message + + @property + def has_pending_tool_response(self) -> bool: + return self._pending_tool_request is not None + + def _stage_tool_response(self, tool_message: PromptMessageExtended) -> None: + if getattr(self._agent.config, "use_history", True): + self._delta_messages = [tool_message] + else: + if self._last_message is not None: + self._delta_messages.append(self._last_message) + self._delta_messages.append(tool_message) + + async def _ensure_tool_response_staged(self) -> None: + if self._pending_tool_request is None: + return + + tool_message = await self.generate_tool_call_response() + if tool_message is None: + return + + error_channel_messages = (tool_message.channels or {}).get(FAST_AGENT_ERROR_CHANNEL) + if error_channel_messages and self._last_message is not None: + tool_result_contents = [ + content + for tool_result in (tool_message.tool_results or {}).values() + for content in tool_result.content + ] + if tool_result_contents: + if self._last_message.content is None: + self._last_message.content = [] + self._last_message.content.extend(tool_result_contents) + self._last_message.stop_reason = LlmStopReason.ERROR + self._done = True + return + + self._iteration += 1 + max_iterations = ( + self._request_params.max_iterations + if self._request_params is not None + else DEFAULT_MAX_ITERATIONS + ) + if self._iteration > max_iterations: + self._done = True + + def _full_history_for_next_call(self) -> list[PromptMessageExtended]: + agent = self._agent + if not hasattr(agent, "_prepare_llm_call"): + return list(self._delta_messages) + call_ctx = getattr(agent, "_prepare_llm_call")(self._delta_messages, self._request_params) + return list(call_ctx.full_history) diff --git a/src/fast_agent/agents/workflow/agents_as_tools_agent.py b/src/fast_agent/agents/workflow/agents_as_tools_agent.py index 444a08fe..583f36f0 100644 --- a/src/fast_agent/agents/workflow/agents_as_tools_agent.py +++ b/src/fast_agent/agents/workflow/agents_as_tools_agent.py @@ -197,7 +197,7 @@ async def coordinator(): pass from mcp.types import CallToolResult from fast_agent.agents.mcp_agent import McpAgent -from fast_agent.constants import FAST_AGENT_ERROR_CHANNEL +from fast_agent.constants import FAST_AGENT_ERROR_CHANNEL, FORCE_SEQUENTIAL_TOOL_CALLS from fast_agent.core.logging.logger import get_logger from fast_agent.core.prompt import Prompt from fast_agent.mcp.helpers.content_helpers import 
get_text, text_content @@ -638,7 +638,6 @@ async def _run_child_tools( call_descriptors: list[dict[str, Any]] = [] descriptor_by_id: dict[str, dict[str, Any]] = {} - tasks: list[asyncio.Task] = [] id_list: list[str] = [] for correlation_id, tool_request in request.tool_calls.items(): @@ -796,15 +795,25 @@ async def call_with_instance_name( ) ) - for i, cid in enumerate(id_list, 1): - tool_name = descriptor_by_id[cid]["tool"] - tool_args = descriptor_by_id[cid]["args"] - tasks.append(asyncio.create_task(call_with_instance_name(tool_name, tool_args, i, cid))) - self._show_parallel_tool_calls(call_descriptors) - if tasks: - results = await asyncio.gather(*tasks, return_exceptions=True) + results: list[CallToolResult | Exception] = [] + if id_list: + if FORCE_SEQUENTIAL_TOOL_CALLS: + for i, cid in enumerate(id_list, 1): + tool_name = descriptor_by_id[cid]["tool"] + tool_args = descriptor_by_id[cid]["args"] + try: + results.append(await call_with_instance_name(tool_name, tool_args, i, cid)) + except Exception as exc: + results.append(exc) + else: + tasks = [] + for i, cid in enumerate(id_list, 1): + tool_name = descriptor_by_id[cid]["tool"] + tool_args = descriptor_by_id[cid]["args"] + tasks.append(asyncio.create_task(call_with_instance_name(tool_name, tool_args, i, cid))) + results = await asyncio.gather(*tasks, return_exceptions=True) for i, result in enumerate(results): correlation_id = id_list[i] if isinstance(result, Exception): diff --git a/src/fast_agent/constants.py b/src/fast_agent/constants.py index dd63761f..e4f11a93 100644 --- a/src/fast_agent/constants.py +++ b/src/fast_agent/constants.py @@ -10,6 +10,9 @@ FAST_AGENT_REMOVED_METADATA_CHANNEL = "fast-agent-removed-meta" FAST_AGENT_TIMING = "fast-agent-timing" FAST_AGENT_TOOL_TIMING = "fast-agent-tool-timing" + +FORCE_SEQUENTIAL_TOOL_CALLS = False +"""Force tool execution to run sequentially even when multiple tool calls are present.""" # should we have MAX_TOOL_CALLS instead to constrain by number of tools rather than turns...? 
DEFAULT_MAX_ITERATIONS = 99 """Maximum number of User/Assistant turns to take""" diff --git a/tests/unit/fast_agent/agents/test_tool_runner_hooks.py b/tests/unit/fast_agent/agents/test_tool_runner_hooks.py new file mode 100644 index 00000000..a1c70a9c --- /dev/null +++ b/tests/unit/fast_agent/agents/test_tool_runner_hooks.py @@ -0,0 +1,117 @@ +import pytest +from mcp import CallToolRequest +from mcp.types import CallToolRequestParams, Tool + +from fast_agent.agents.agent_types import AgentConfig +from fast_agent.agents.tool_agent import ToolAgent +from fast_agent.agents.tool_runner import ToolRunnerHooks +from fast_agent.core.prompt import Prompt +from fast_agent.llm.internal.passthrough import PassthroughLLM +from fast_agent.llm.request_params import RequestParams +from fast_agent.mcp.helpers.content_helpers import get_text +from fast_agent.mcp.prompt_message_extended import PromptMessageExtended +from fast_agent.types.llm_stop_reason import LlmStopReason + + +def tool_one() -> int: + return 1 + + +def tool_two() -> int: + return 2 + + +class TwoStepToolUseLlm(PassthroughLLM): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.calls: list[list[str]] = [] + self._turn = 0 + + async def _apply_prompt_provider_specific( + self, + multipart_messages: list[PromptMessageExtended], + request_params: RequestParams | None = None, + tools: list[Tool] | None = None, + is_template: bool = False, + ) -> PromptMessageExtended: + self._turn += 1 + self.calls.append( + [ + get_text(block) or "" + for msg in multipart_messages + for block in (msg.content or []) + if get_text(block) + ] + ) + + if self._turn == 1: + tool_calls = { + "id_one": CallToolRequest( + method="tools/call", + params=CallToolRequestParams(name="tool_one", arguments={}), + ), + "id_two": CallToolRequest( + method="tools/call", + params=CallToolRequestParams(name="tool_two", arguments={}), + ), + } + return Prompt.assistant( + "use tools", + stop_reason=LlmStopReason.TOOL_USE, + tool_calls=tool_calls, + ) + + return Prompt.assistant("done", stop_reason=LlmStopReason.END_TURN) + + +class HookedToolAgent(ToolAgent): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.events: list[str] = [] + self._injected = False + + def _tool_runner_hooks(self) -> ToolRunnerHooks | None: + async def before_llm_call(runner, messages): + self.events.append(f"before_llm_call:{runner.iteration}") + if not self._injected: + runner.append_messages("extra from hook") + self._injected = True + + async def after_llm_call(runner, message): + self.events.append(f"after_llm_call:{message.stop_reason}") + + async def before_tool_call(runner, message): + self.events.append(f"before_tool_call:{len(message.tool_calls or {})}") + + async def after_tool_call(runner, message): + self.events.append(f"after_tool_call:{len(message.tool_results or {})}") + + return ToolRunnerHooks( + before_llm_call=before_llm_call, + after_llm_call=after_llm_call, + before_tool_call=before_tool_call, + after_tool_call=after_tool_call, + ) + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_tool_runner_hooks_fire_and_can_inject_messages(): + llm = TwoStepToolUseLlm() + agent = HookedToolAgent(AgentConfig("hooked"), [tool_one, tool_two]) + agent._llm = llm + + result = await agent.generate("hi") + assert result.last_text() == "done" + + assert any("extra from hook" in entry for entry in llm.calls[0]) + + assert agent.events == [ + "before_llm_call:0", + f"after_llm_call:{LlmStopReason.TOOL_USE}", + "before_tool_call:2", + 
"after_tool_call:2", + "before_llm_call:1", + f"after_llm_call:{LlmStopReason.END_TURN}", + ] + From 522382a451732999a31773081a34b568dbc95c4f Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Wed, 17 Dec 2025 11:49:03 +0000 Subject: [PATCH 2/5] tidy up --- src/fast_agent/agents/mcp_agent.py | 4 +--- src/fast_agent/agents/tool_agent.py | 5 ++--- src/fast_agent/agents/tool_runner.py | 5 +---- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/fast_agent/agents/mcp_agent.py b/src/fast_agent/agents/mcp_agent.py index 93303f58..2d3195fd 100644 --- a/src/fast_agent/agents/mcp_agent.py +++ b/src/fast_agent/agents/mcp_agent.py @@ -7,6 +7,7 @@ import asyncio import fnmatch +import time from abc import ABC from typing import ( TYPE_CHECKING, @@ -776,9 +777,6 @@ async def with_resource( async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtended: """Override ToolAgent's run_tools to use MCP tools via aggregator.""" - import asyncio - import time - if not request.tool_calls: self.logger.warning("No tool calls found in request", data=request) return PromptMessageExtended(role="user", tool_results={}) diff --git a/src/fast_agent/agents/tool_agent.py b/src/fast_agent/agents/tool_agent.py index e11910a8..de2dd1d6 100644 --- a/src/fast_agent/agents/tool_agent.py +++ b/src/fast_agent/agents/tool_agent.py @@ -1,3 +1,5 @@ +import asyncio +import time from typing import Any, Callable, Dict, List, Sequence from mcp.server.fastmcp.tools.base import Tool as FastMCPTool @@ -115,9 +117,6 @@ def show_user_message(self, message: PromptMessageExtended) -> None: async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtended: """Runs the tools in the request, and returns a new User message with the results""" - import asyncio - import time - if not request.tool_calls: logger.warning("No tool calls found in request", data=request) return PromptMessageExtended(role="user", tool_results={}) diff --git a/src/fast_agent/agents/tool_runner.py b/src/fast_agent/agents/tool_runner.py index 2e5c49c8..80809ffa 100644 --- a/src/fast_agent/agents/tool_runner.py +++ b/src/fast_agent/agents/tool_runner.py @@ -19,8 +19,6 @@ if TYPE_CHECKING: from mcp import Tool - from fast_agent.agents.llm_decorator import LlmDecorator # noqa: F401 - class _ToolLoopAgent(Protocol): config: Any @@ -94,10 +92,9 @@ async def __anext__(self) -> PromptMessageExtended: if self._done: raise StopAsyncIteration - full_history = self._full_history_for_next_call() if self._hooks.before_llm_call is not None: - await self._hooks.before_llm_call(self, full_history) full_history = self._full_history_for_next_call() + await self._hooks.before_llm_call(self, full_history) assistant_message = await self._agent._tool_runner_llm_step( self._delta_messages, From 235fd0bde8f6c1fcc4ebe195d2dea65fbed8b744 Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Wed, 17 Dec 2025 20:58:35 +0000 Subject: [PATCH 3/5] refactor 1 --- src/fast_agent/agents/mcp_agent.py | 4 ++-- src/fast_agent/agents/tool_agent.py | 13 ++++++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/fast_agent/agents/mcp_agent.py b/src/fast_agent/agents/mcp_agent.py index 2d3195fd..97e4bd69 100644 --- a/src/fast_agent/agents/mcp_agent.py +++ b/src/fast_agent/agents/mcp_agent.py @@ -36,7 +36,7 @@ from fast_agent.agents.agent_types import AgentConfig, AgentType from fast_agent.agents.llm_agent import DEFAULT_CAPABILITIES -from 
fast_agent.agents.tool_agent import ToolAgent +from fast_agent.agents.tool_agent import ToolAgent, ToolTimingInfo from fast_agent.constants import FORCE_SEQUENTIAL_TOOL_CALLS, HUMAN_INPUT_TOOL_NAME from fast_agent.core.exceptions import PromptExitError from fast_agent.core.logging.logger import get_logger @@ -782,7 +782,7 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend return PromptMessageExtended(role="user", tool_results={}) tool_results: dict[str, CallToolResult] = {} - tool_timings: dict[str, dict[str, float | str | None]] = {} + tool_timings: dict[str, ToolTimingInfo] = {} tool_loop_error: str | None = None # Cache available tool names exactly as advertised to the LLM for display/highlighting diff --git a/src/fast_agent/agents/tool_agent.py b/src/fast_agent/agents/tool_agent.py index de2dd1d6..f97ccc40 100644 --- a/src/fast_agent/agents/tool_agent.py +++ b/src/fast_agent/agents/tool_agent.py @@ -1,6 +1,6 @@ import asyncio import time -from typing import Any, Callable, Dict, List, Sequence +from typing import Any, Callable, Dict, List, Sequence, TypedDict from mcp.server.fastmcp.tools.base import Tool as FastMCPTool from mcp.types import CallToolResult, ListToolsResult, Tool @@ -22,6 +22,13 @@ logger = get_logger(__name__) +class ToolTimingInfo(TypedDict): + """Timing information for a single tool call.""" + + timing_ms: float + transport_channel: str | None + + class ToolAgent(LlmAgent): """ A Tool Calling agent that uses FastMCP Tools for execution. @@ -122,7 +129,7 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend return PromptMessageExtended(role="user", tool_results={}) tool_results: dict[str, CallToolResult] = {} - tool_timings: dict[str, dict[str, float | str | None]] = {} + tool_timings: dict[str, ToolTimingInfo] = {} tool_loop_error: str | None = None # TODO -- use gather() for parallel results, update display tool_schemas = (await self.list_tools()).tools @@ -254,7 +261,7 @@ def _finalize_tool_results( self, tool_results: dict[str, CallToolResult], *, - tool_timings: dict[str, dict[str, float | str | None]] | None = None, + tool_timings: dict[str, ToolTimingInfo] | None = None, tool_loop_error: str | None = None, ) -> PromptMessageExtended: import json From 261ea93ab06146ec7a27fd1622a45b71846911c2 Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Thu, 18 Dec 2025 23:44:41 +0000 Subject: [PATCH 4/5] type safety. history management still split between decorator/runner. 
--- src/fast_agent/agents/tool_agent.py | 1 - src/fast_agent/agents/tool_runner.py | 24 ++++++++++-------------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/src/fast_agent/agents/tool_agent.py b/src/fast_agent/agents/tool_agent.py index f97ccc40..bb5bb415 100644 --- a/src/fast_agent/agents/tool_agent.py +++ b/src/fast_agent/agents/tool_agent.py @@ -207,7 +207,6 @@ async def run_one( ) for correlation_id, tool_name, tool_args in planned_calls: - # Find the index of the current tool in available_tools for highlighting highlight_index = None try: diff --git a/src/fast_agent/agents/tool_runner.py b/src/fast_agent/agents/tool_runner.py index 80809ffa..22e8d1e2 100644 --- a/src/fast_agent/agents/tool_runner.py +++ b/src/fast_agent/agents/tool_runner.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from typing import ( TYPE_CHECKING, - Any, Awaitable, Callable, Protocol, @@ -20,8 +19,12 @@ from mcp import Tool +class _AgentConfig(Protocol): + use_history: bool + + class _ToolLoopAgent(Protocol): - config: Any + config: _AgentConfig async def _tool_runner_llm_step( self, @@ -93,8 +96,7 @@ async def __anext__(self) -> PromptMessageExtended: raise StopAsyncIteration if self._hooks.before_llm_call is not None: - full_history = self._full_history_for_next_call() - await self._hooks.before_llm_call(self, full_history) + await self._hooks.before_llm_call(self, self._delta_messages) assistant_message = await self._agent._tool_runner_llm_step( self._delta_messages, @@ -157,8 +159,9 @@ def append_messages(self, *messages: Union[str, PromptMessageExtended]) -> None: self._delta_messages.append(message) @property - def messages(self) -> list[PromptMessageExtended]: - return self._full_history_for_next_call() + def delta_messages(self) -> list[PromptMessageExtended]: + """Messages to be sent in the next LLM call (not full history).""" + return self._delta_messages @property def iteration(self) -> int: @@ -177,7 +180,7 @@ def has_pending_tool_response(self) -> bool: return self._pending_tool_request is not None def _stage_tool_response(self, tool_message: PromptMessageExtended) -> None: - if getattr(self._agent.config, "use_history", True): + if self._agent.config.use_history: self._delta_messages = [tool_message] else: if self._last_message is not None: @@ -215,10 +218,3 @@ async def _ensure_tool_response_staged(self) -> None: ) if self._iteration > max_iterations: self._done = True - - def _full_history_for_next_call(self) -> list[PromptMessageExtended]: - agent = self._agent - if not hasattr(agent, "_prepare_llm_call"): - return list(self._delta_messages) - call_ctx = getattr(agent, "_prepare_llm_call")(self._delta_messages, self._request_params) - return list(call_ctx.full_history) From 9e9d123baf8d0cbdb9024b6af0f0a80ad9c664ea Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Fri, 19 Dec 2025 00:02:05 +0000 Subject: [PATCH 5/5] reset cached tool response in loop: test --- src/fast_agent/agents/tool_runner.py | 1 + .../agents/test_tool_runner_hooks.py | 78 +++++++++++++++++++ 2 files changed, 79 insertions(+) diff --git a/src/fast_agent/agents/tool_runner.py b/src/fast_agent/agents/tool_runner.py index 22e8d1e2..02424f1a 100644 --- a/src/fast_agent/agents/tool_runner.py +++ b/src/fast_agent/agents/tool_runner.py @@ -110,6 +110,7 @@ async def __anext__(self) -> PromptMessageExtended: if assistant_message.stop_reason == LlmStopReason.TOOL_USE: self._pending_tool_request = assistant_message + self._pending_tool_response = None # Clear cache for 
new request else: self._done = True diff --git a/tests/unit/fast_agent/agents/test_tool_runner_hooks.py b/tests/unit/fast_agent/agents/test_tool_runner_hooks.py index a1c70a9c..6c035af0 100644 --- a/tests/unit/fast_agent/agents/test_tool_runner_hooks.py +++ b/tests/unit/fast_agent/agents/test_tool_runner_hooks.py @@ -115,3 +115,81 @@ async def test_tool_runner_hooks_fire_and_can_inject_messages(): f"after_llm_call:{LlmStopReason.END_TURN}", ] + +# Track tool invocations globally for the regression test +_tool_invocations: list[str] = [] + + +def tracked_tool_a() -> str: + _tool_invocations.append("tool_a") + return "result_a" + + +def tracked_tool_b() -> str: + _tool_invocations.append("tool_b") + return "result_b" + + +class TwoRoundToolUseLlm(PassthroughLLM): + """LLM that returns tool_use twice before completing.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._turn = 0 + + async def _apply_prompt_provider_specific( + self, + multipart_messages: list[PromptMessageExtended], + request_params: RequestParams | None = None, + tools: list[Tool] | None = None, + is_template: bool = False, + ) -> PromptMessageExtended: + self._turn += 1 + + if self._turn == 1: + # First round: call tool_a + return Prompt.assistant( + "calling tool_a", + stop_reason=LlmStopReason.TOOL_USE, + tool_calls={ + "call_1": CallToolRequest( + method="tools/call", + params=CallToolRequestParams(name="tracked_tool_a", arguments={}), + ), + }, + ) + + if self._turn == 2: + # Second round: call tool_b + return Prompt.assistant( + "calling tool_b", + stop_reason=LlmStopReason.TOOL_USE, + tool_calls={ + "call_2": CallToolRequest( + method="tools/call", + params=CallToolRequestParams(name="tracked_tool_b", arguments={}), + ), + }, + ) + + return Prompt.assistant("done", stop_reason=LlmStopReason.END_TURN) + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_two_tool_use_rounds_both_execute(): + """Regression test: ensure second tool-use round executes new tools, not cached response.""" + _tool_invocations.clear() + + llm = TwoRoundToolUseLlm() + agent = ToolAgent(AgentConfig("test"), [tracked_tool_a, tracked_tool_b]) + agent._llm = llm + + result = await agent.generate("hi") + assert result.last_text() == "done" + + # Both tools must have been called - if caching bug exists, only tool_a would be called + assert _tool_invocations == ["tool_a", "tool_b"], ( + f"Expected both tools to execute, got: {_tool_invocations}" + ) +
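
A minimal usage sketch of the new tool runner (illustrative only, not part of
the patches above): it drives ToolRunner directly as an async iterator instead
of going through ToolAgent.generate_impl(). Only APIs introduced in patch 1/5
are used; the `drive` helper and its bare `agent` argument are assumptions for
the example, and `agent` is expected to be a ToolAgent with an LLM already
attached, as in the unit tests.

from mcp.types import TextContent

from fast_agent.agents.tool_runner import ToolRunner, ToolRunnerHooks
from fast_agent.types import PromptMessageExtended


async def drive(agent) -> PromptMessageExtended | None:
    # Mirror ToolAgent.generate_impl: resolve the tool schemas up front.
    tools = (await agent.list_tools()).tools

    async def after_tool_call(runner, tool_message):
        # Hooks can inspect tool results, adjust parameters via
        # runner.set_request_params(), or inject extra messages via
        # runner.append_messages().
        print(f"turn {runner.iteration}: {len(tool_message.tool_results or {})} result(s)")

    runner = ToolRunner(
        agent=agent,
        messages=[
            PromptMessageExtended(
                role="user",
                content=[TextContent(type="text", text="hi")],
            )
        ],
        request_params=None,
        tools=tools,
        hooks=ToolRunnerHooks(after_tool_call=after_tool_call),
    )

    # Each iteration performs one LLM step and yields the assistant message.
    # If that message requested tools, the results are executed and staged
    # before the next step; the loop ends once the stop reason is no longer
    # TOOL_USE or max_iterations is exceeded.
    async for assistant_message in runner:
        print(assistant_message.stop_reason)

    # (Instead of the manual loop above, `await runner.until_done()` returns
    # this final message directly — that is what generate_impl does.)
    return runner.last_message

Note that staging honours config.use_history exactly as the old inline loop
did: with history enabled only the tool-result message is carried forward as
the delta, otherwise the assistant message and tool results are both appended.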