diff --git a/src/mcp_cli/chat/chat_context.py b/src/mcp_cli/chat/chat_context.py index 362e254e..097328fc 100644 --- a/src/mcp_cli/chat/chat_context.py +++ b/src/mcp_cli/chat/chat_context.py @@ -236,7 +236,32 @@ def _generate_system_prompt(self) -> None: tools_for_prompt = [ tool.to_llm_format().to_dict() for tool in self.internal_tools ] - self._system_prompt = generate_system_prompt(tools_for_prompt) + server_tool_groups = self._build_server_tool_groups() + self._system_prompt = generate_system_prompt( + tools=tools_for_prompt, + server_tool_groups=server_tool_groups, + ) + + def _build_server_tool_groups(self) -> list[dict[str, Any]]: + """Build server-to-tools grouping for the system prompt.""" + if not self.server_info: + return [] + + # Group tools by server namespace + server_tools: dict[str, list[str]] = {} + for tool_name, namespace in self.tool_to_server_map.items(): + server_tools.setdefault(namespace, []).append(tool_name) + + groups = [] + for server in self.server_info: + tools = server_tools.get(server.namespace, []) + if tools: + groups.append({ + "name": server.name, + "description": server.display_description, + "tools": sorted(tools), + }) + return groups async def _initialize_tools(self) -> None: """Initialize tool discovery and adaptation.""" diff --git a/src/mcp_cli/chat/conversation.py b/src/mcp_cli/chat/conversation.py index 71e84f84..9431908b 100644 --- a/src/mcp_cli/chat/conversation.py +++ b/src/mcp_cli/chat/conversation.py @@ -532,9 +532,13 @@ async def _handle_streaming_completion( try: # stream_response returns dict, convert to CompletionResponse + messages_for_api = [ + msg.to_dict() for msg in self.context.conversation_history + ] + messages_for_api = self._validate_tool_messages(messages_for_api) completion_dict = await streaming_handler.stream_response( client=self.context.client, - messages=[msg.to_dict() for msg in self.context.conversation_history], + messages=messages_for_api, tools=tools, ) @@ -573,6 +577,7 @@ async def _handle_regular_completion( messages_as_dicts = [ msg.to_dict() for msg in self.context.conversation_history ] + messages_as_dicts = self._validate_tool_messages(messages_as_dicts) completion_dict = await self.context.client.create_completion( messages=messages_as_dicts, tools=tools, @@ -588,6 +593,7 @@ async def _handle_regular_completion( messages_as_dicts = [ msg.to_dict() for msg in self.context.conversation_history ] + messages_as_dicts = self._validate_tool_messages(messages_as_dicts) completion_dict = await self.context.client.create_completion( messages=messages_as_dicts ) @@ -632,6 +638,58 @@ async def _load_tools(self): self.context.openai_tools = [] self.context.tool_name_mapping = {} + @staticmethod + def _validate_tool_messages(messages: list[dict]) -> list[dict]: + """Ensure every assistant tool_call_id has a matching tool result. + + Defense-in-depth: repairs orphaned tool_calls before sending to the API. + Without this, OpenAI returns a 400 error: + "An assistant message with 'tool_calls' must be followed by tool messages + responding to each 'tool_call_id'." + + Args: + messages: List of message dicts about to be sent to the API. + + Returns: + The message list, with placeholder tool results inserted for any + orphaned tool_call_ids. 
+ """ + repaired: list[dict] = [] + i = 0 + while i < len(messages): + msg = messages[i] + repaired.append(msg) + + if msg.get("role") == "assistant" and msg.get("tool_calls"): + # Collect expected tool_call_ids from this assistant message + expected_ids = set() + for tc in msg["tool_calls"]: + tc_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None) + if tc_id: + expected_ids.add(tc_id) + + # Scan following messages for matching tool results + j = i + 1 + found_ids: set[str] = set() + while j < len(messages) and messages[j].get("role") == "tool": + tid = messages[j].get("tool_call_id") + if tid: + found_ids.add(tid) + j += 1 + + # Insert placeholders for any missing tool results + missing = expected_ids - found_ids + for mid in missing: + log.warning(f"Repairing orphaned tool_call_id: {mid}") + repaired.append({ + "role": "tool", + "tool_call_id": mid, + "content": "Tool call did not complete.", + }) + + i += 1 + return repaired + def _register_user_literals_from_history(self) -> int: """Extract and register numeric literals from recent user messages. diff --git a/src/mcp_cli/chat/system_prompt.py b/src/mcp_cli/chat/system_prompt.py index fc1d6916..04cbae77 100644 --- a/src/mcp_cli/chat/system_prompt.py +++ b/src/mcp_cli/chat/system_prompt.py @@ -2,7 +2,31 @@ import os -def generate_system_prompt(tools=None): +def _build_server_section(server_tool_groups): + """Build the server/tool categorization section for the system prompt.""" + if not server_tool_groups: + return "" + + lines = [ + "", + "**CONNECTED SERVERS & AVAILABLE TOOLS:**", + "", + "You have access to tools from the following servers. Consider using tools", + "from ALL relevant servers when answering a query.", + "", + ] + for group in server_tool_groups: + name = group.get("name", "unknown") + desc = group.get("description", "") + tools = group.get("tools", []) + tool_list = ", ".join(tools) + lines.append(f"- **{name}** ({desc}): {tool_list}") + + lines.append("") + return "\n".join(lines) + + +def generate_system_prompt(tools=None, server_tool_groups=None): """Generate a concise system prompt for the assistant. Note: Tool definitions are passed via the API's tools parameter, @@ -10,6 +34,11 @@ def generate_system_prompt(tools=None): When dynamic tools mode is enabled (MCP_CLI_DYNAMIC_TOOLS=1), generates a special prompt explaining the tool discovery workflow. + + Args: + tools: List of tool definitions (dicts or ToolInfo objects). + server_tool_groups: Optional list of dicts with server/tool grouping, + each containing {"name", "description", "tools"}. """ # Check if dynamic tools mode is enabled dynamic_mode = os.environ.get("MCP_CLI_DYNAMIC_TOOLS") == "1" @@ -20,9 +49,13 @@ def generate_system_prompt(tools=None): # Count tools for the prompt (tools may be ToolInfo objects or dicts) tool_count = len(tools) if tools else 0 + # Build server/tool categorization section + server_section = _build_server_section(server_tool_groups) + system_prompt = f"""You are an intelligent assistant with access to {tool_count} tools to help solve user queries effectively. Use the available tools when appropriate to accomplish tasks. Tools are provided via the API and you can call them as needed. 
+{server_section} **GENERAL GUIDELINES:** diff --git a/src/mcp_cli/chat/tool_processor.py b/src/mcp_cli/chat/tool_processor.py index 5483aa6c..344b4005 100644 --- a/src/mcp_cli/chat/tool_processor.py +++ b/src/mcp_cli/chat/tool_processor.py @@ -66,6 +66,9 @@ def __init__( self._call_metadata: dict[str, dict[str, Any]] = {} self._cancelled = False + # Track which tool_call_ids have received results (for orphan detection) + self._result_ids_added: set[str] = set() + # Give the context a back-pointer for Ctrl-C cancellation # Note: This is the one place we set an attribute on context context.tool_processor = self @@ -98,6 +101,7 @@ async def process_tool_calls( # Reset state self._call_metadata.clear() self._cancelled = False + self._result_ids_added = set() # Add assistant message with all tool calls BEFORE executing self._add_assistant_message_with_tool_calls(tool_calls, reasoning_content) @@ -105,249 +109,248 @@ async def process_tool_calls( # Convert LLM tool calls to CTP format and check confirmations ctp_calls: list[CTPToolCall] = [] - for idx, call in enumerate(tool_calls): - if getattr(self.ui_manager, "interrupt_requested", False): - self._cancelled = True - break - - # Extract tool call details - llm_tool_name, raw_arguments, call_id = self._extract_tool_call_info( - call, idx - ) - - # Map to execution name - execution_tool_name = name_mapping.get(llm_tool_name, llm_tool_name) - - # Get display name - special handling for dynamic tool call_tool - display_name = execution_tool_name - display_arguments = raw_arguments - - # For dynamic tools, extract the actual tool name from call_tool - if execution_tool_name == "call_tool": - # Parse arguments to get the real tool name - parsed_args = self._parse_arguments(raw_arguments) - if "tool_name" in parsed_args: - actual_tool = parsed_args["tool_name"] - # Show as "call_tool → actual_tool_name" - display_name = f"call_tool → {actual_tool}" - # Filter out tool_name from displayed args to reduce noise - display_arguments = { - k: v for k, v in parsed_args.items() if k != "tool_name" - } - - if hasattr(self.context, "get_display_name_for_tool"): - # Only apply name mapping if not already a dynamic tool - if not execution_tool_name.startswith("call_tool"): - display_name = self.context.get_display_name_for_tool( - execution_tool_name - ) - - # Show tool call in UI - try: - self.ui_manager.print_tool_call(display_name, display_arguments) - except Exception as ui_exc: - log.warning(f"UI display error (non-fatal): {ui_exc}") - - # Handle user confirmation - if self._should_confirm_tool(execution_tool_name): - confirmed = self.ui_manager.do_confirm_tool_execution( - tool_name=display_name, arguments=raw_arguments - ) - if not confirmed: - setattr(self.ui_manager, "interrupt_requested", True) - self._add_cancelled_tool_to_history( - llm_tool_name, call_id, raw_arguments - ) + try: + for idx, call in enumerate(tool_calls): + if getattr(self.ui_manager, "interrupt_requested", False): self._cancelled = True break - # Parse arguments - arguments = self._parse_arguments(raw_arguments) - - # DEBUG: Log exactly what the model sent for this tool call - log.info(f"TOOL CALL FROM MODEL: {llm_tool_name} id={call_id}") - log.info(f" raw_arguments: {raw_arguments}") - log.info(f" parsed_arguments: {arguments}") - - # Get actual tool name for checks (for call_tool, it's the inner tool) - actual_tool_for_checks = execution_tool_name - if execution_tool_name == "call_tool" and "tool_name" in arguments: - actual_tool_for_checks = arguments["tool_name"] - - # GENERIC 
VALIDATION: Reject tool calls with None arguments - # This catches cases where the model emits placeholders or incomplete calls - none_args = [ - k for k, v in arguments.items() if v is None and k != "tool_name" - ] - if none_args: - error_msg = ( - f"INVALID_ARGS: Tool '{actual_tool_for_checks}' called with None values " - f"for: {', '.join(none_args)}. Please provide actual values." - ) - log.warning(error_msg) - output.warning(f"⚠ {error_msg}") - self._add_tool_result_to_history( - llm_tool_name, - call_id, - f"**Error**: {error_msg}\n\nPlease retry with actual parameter values.", + # Extract tool call details + llm_tool_name, raw_arguments, call_id = self._extract_tool_call_info( + call, idx ) - continue - # Check $vN references in arguments (dataflow validation) - tool_state = get_tool_state() - ref_check = tool_state.check_references(arguments) - if not ref_check.valid: - log.warning( - f"Missing references in {actual_tool_for_checks}: {ref_check.message}" - ) - output.warning(f"⚠ {ref_check.message}") - # Add error to history instead of executing - self._add_tool_result_to_history( - llm_tool_name, - call_id, - f"**Blocked**: {ref_check.message}\n\n" - f"{tool_state.format_bindings_for_model()}", - ) - continue - - # Check for ungrounded calls (numeric args without $vN refs) - # Skip discovery tools - they don't need grounded numeric inputs - # Skip idempotent math tools - they should be allowed to compute with any literals - # Use SoftBlock repair system: attempt rebind → symbolic fallback → ask user - is_math_tool = tool_state.is_idempotent_math_tool(actual_tool_for_checks) - if ( - not tool_state.is_discovery_tool(execution_tool_name) - and not is_math_tool - ): - ungrounded_check = tool_state.check_ungrounded_call( - actual_tool_for_checks, arguments - ) - if ungrounded_check.is_ungrounded: - # Log args for observability (important for debugging) - log.info( - f"Ungrounded call to {actual_tool_for_checks} with args: {arguments}" - ) + # Map to execution name + execution_tool_name = name_mapping.get(llm_tool_name, llm_tool_name) + + # Get display name - special handling for dynamic tool call_tool + display_name = execution_tool_name + display_arguments = raw_arguments + + # For dynamic tools, extract the actual tool name from call_tool + if execution_tool_name == "call_tool": + # Parse arguments to get the real tool name + parsed_args = self._parse_arguments(raw_arguments) + if "tool_name" in parsed_args: + actual_tool = parsed_args["tool_name"] + # Show as "call_tool → actual_tool_name" + display_name = f"call_tool → {actual_tool}" + # Filter out tool_name from displayed args to reduce noise + display_arguments = { + k: v for k, v in parsed_args.items() if k != "tool_name" + } - # Check if this tool should have auto-rebound applied - # Parameterized tools (normal_cdf, sqrt, etc.) 
should NOT be rebound - # because each call with different args has different semantics - if not tool_state.should_auto_rebound(actual_tool_for_checks): - # For parameterized tools, check preconditions first - # This blocks premature calls before any values are computed - precond_ok, precond_error = tool_state.check_tool_preconditions( - actual_tool_for_checks, arguments + if hasattr(self.context, "get_display_name_for_tool"): + # Only apply name mapping if not already a dynamic tool + if not execution_tool_name.startswith("call_tool"): + display_name = self.context.get_display_name_for_tool( + execution_tool_name ) - if not precond_ok: - log.warning( - f"Precondition failed for {actual_tool_for_checks}" - ) - output.warning( - f"⚠ Precondition failed for {actual_tool_for_checks}" - ) - self._add_tool_result_to_history( - llm_tool_name, call_id, f"**Blocked**: {precond_error}" - ) - continue - # Preconditions met - log and allow execution - display_args = { - k: v for k, v in arguments.items() if k != "tool_name" - } - log.info( - f"Allowing parameterized tool {actual_tool_for_checks} with args: {display_args}" + # Show tool call in UI + try: + self.ui_manager.print_tool_call(display_name, display_arguments) + except Exception as ui_exc: + log.warning(f"UI display error (non-fatal): {ui_exc}") + + # Handle user confirmation + if self._should_confirm_tool(execution_tool_name): + confirmed = self.ui_manager.do_confirm_tool_execution( + tool_name=display_name, arguments=raw_arguments + ) + if not confirmed: + setattr(self.ui_manager, "interrupt_requested", True) + self._add_cancelled_tool_to_history( + llm_tool_name, call_id, raw_arguments ) - output.info(f"→ {actual_tool_for_checks} args: {display_args}") - # Fall through to execution - else: - # For other tools, try to repair using SoftBlock system - should_proceed, repaired_args, fallback_response = ( - tool_state.try_soft_block_repair( - actual_tool_for_checks, - arguments, - SoftBlockReason.UNGROUNDED_ARGS, - ) + self._cancelled = True + break + + # Parse arguments + arguments = self._parse_arguments(raw_arguments) + + # DEBUG: Log exactly what the model sent for this tool call + log.info(f"TOOL CALL FROM MODEL: {llm_tool_name} id={call_id}") + log.info(f" raw_arguments: {raw_arguments}") + log.info(f" parsed_arguments: {arguments}") + + # Get actual tool name for checks (for call_tool, it's the inner tool) + actual_tool_for_checks = execution_tool_name + if execution_tool_name == "call_tool" and "tool_name" in arguments: + actual_tool_for_checks = arguments["tool_name"] + + # GENERIC VALIDATION: Reject tool calls with None arguments + # This catches cases where the model emits placeholders or incomplete calls + none_args = [ + k for k, v in arguments.items() if v is None and k != "tool_name" + ] + if none_args: + error_msg = ( + f"INVALID_ARGS: Tool '{actual_tool_for_checks}' called with None values " + f"for: {', '.join(none_args)}. Please provide actual values." 
+ ) + log.warning(error_msg) + output.warning(f"⚠ {error_msg}") + self._add_tool_result_to_history( + llm_tool_name, + call_id, + f"**Error**: {error_msg}\n\nPlease retry with actual parameter values.", + ) + continue + + # Check $vN references in arguments (dataflow validation) + tool_state = get_tool_state() + ref_check = tool_state.check_references(arguments) + if not ref_check.valid: + log.warning( + f"Missing references in {actual_tool_for_checks}: {ref_check.message}" + ) + output.warning(f"⚠ {ref_check.message}") + # Add error to history instead of executing + self._add_tool_result_to_history( + llm_tool_name, + call_id, + f"**Blocked**: {ref_check.message}\n\n" + f"{tool_state.format_bindings_for_model()}", + ) + continue + + # Check for ungrounded calls (numeric args without $vN refs) + # Skip discovery tools - they don't need grounded numeric inputs + # Skip idempotent math tools - they should be allowed to compute with any literals + # Use SoftBlock repair system: attempt rebind → symbolic fallback → ask user + is_math_tool = tool_state.is_idempotent_math_tool(actual_tool_for_checks) + if ( + not tool_state.is_discovery_tool(execution_tool_name) + and not is_math_tool + ): + ungrounded_check = tool_state.check_ungrounded_call( + actual_tool_for_checks, arguments + ) + if ungrounded_check.is_ungrounded: + # Log args for observability (important for debugging) + log.info( + f"Ungrounded call to {actual_tool_for_checks} with args: {arguments}" ) - if should_proceed and repaired_args: - # Rebind succeeded - use repaired arguments - log.info( - f"Auto-repaired ungrounded call to {actual_tool_for_checks}: " - f"{arguments} -> {repaired_args}" + # Check if this tool should have auto-rebound applied + # Parameterized tools (normal_cdf, sqrt, etc.) 
should NOT be rebound + # because each call with different args has different semantics + if not tool_state.should_auto_rebound(actual_tool_for_checks): + # For parameterized tools, check preconditions first + # This blocks premature calls before any values are computed + precond_ok, precond_error = tool_state.check_tool_preconditions( + actual_tool_for_checks, arguments ) - output.info( - f"↻ Auto-rebound arguments for {actual_tool_for_checks}" - ) - arguments = repaired_args - elif fallback_response: - # Symbolic fallback - return helpful response instead of blocking - # Show visible annotation for observability - log.info(f"Symbolic fallback for {actual_tool_for_checks}") - output.info( - f"⏸ [analysis] required_input_missing for {actual_tool_for_checks}" - ) - self._add_tool_result_to_history( - llm_tool_name, call_id, fallback_response + if not precond_ok: + log.warning( + f"Precondition failed for {actual_tool_for_checks}" + ) + output.warning( + f"⚠ Precondition failed for {actual_tool_for_checks}" + ) + self._add_tool_result_to_history( + llm_tool_name, call_id, f"**Blocked**: {precond_error}" + ) + continue + + # Preconditions met - log and allow execution + display_args = { + k: v for k, v in arguments.items() if k != "tool_name" + } + log.info( + f"Allowing parameterized tool {actual_tool_for_checks} with args: {display_args}" ) - continue + output.info(f"→ {actual_tool_for_checks} args: {display_args}") + # Fall through to execution else: - # All repairs failed - add error to history - log.warning( - f"Could not repair ungrounded call to {actual_tool_for_checks}" - ) - self._add_tool_result_to_history( - llm_tool_name, - call_id, - f"Cannot proceed with `{actual_tool_for_checks}`: " - f"arguments require computed values.\n\n" - f"{tool_state.format_bindings_for_model()}", + # For other tools, try to repair using SoftBlock system + should_proceed, repaired_args, fallback_response = ( + tool_state.try_soft_block_repair( + actual_tool_for_checks, + arguments, + SoftBlockReason.UNGROUNDED_ARGS, + ) ) - continue - - # Check per-tool call limit using the guard (handles exemptions for math/discovery) - # per_tool_cap=0 means "disabled/unlimited" (see RuntimeLimits presets) - per_tool_result = tool_state.check_per_tool_limit(actual_tool_for_checks) - if tool_state.limits.per_tool_cap > 0 and per_tool_result.blocked: - log.warning(f"Tool {actual_tool_for_checks} blocked by per-tool limit") - output.warning( - f"⚠ Tool {actual_tool_for_checks} - {per_tool_result.reason}" - ) - self._add_tool_result_to_history( - llm_tool_name, - call_id, - per_tool_result.reason or "Per-tool limit reached", - ) - continue - - # Resolve $vN references in arguments (substitute actual values) - resolved_arguments = tool_state.resolve_references(arguments) - - # Store metadata for callbacks - self._call_metadata[call_id] = { - "llm_tool_name": llm_tool_name, - "execution_tool_name": execution_tool_name, - "display_name": display_name, - "arguments": resolved_arguments, # Use resolved arguments - "raw_arguments": raw_arguments, - } - - # Create CTP ToolCall with resolved arguments - ctp_calls.append( - CTPToolCall( - id=call_id, - tool=execution_tool_name, - arguments=resolved_arguments, + + if should_proceed and repaired_args: + # Rebind succeeded - use repaired arguments + log.info( + f"Auto-repaired ungrounded call to {actual_tool_for_checks}: " + f"{arguments} -> {repaired_args}" + ) + output.info( + f"↻ Auto-rebound arguments for {actual_tool_for_checks}" + ) + arguments = repaired_args + elif 
fallback_response: + # Symbolic fallback - return helpful response instead of blocking + # Show visible annotation for observability + log.info(f"Symbolic fallback for {actual_tool_for_checks}") + output.info( + f"⏸ [analysis] required_input_missing for {actual_tool_for_checks}" + ) + self._add_tool_result_to_history( + llm_tool_name, call_id, fallback_response + ) + continue + else: + # All repairs failed - add error to history + log.warning( + f"Could not repair ungrounded call to {actual_tool_for_checks}" + ) + self._add_tool_result_to_history( + llm_tool_name, + call_id, + f"Cannot proceed with `{actual_tool_for_checks}`: " + f"arguments require computed values.\n\n" + f"{tool_state.format_bindings_for_model()}", + ) + continue + + # Check per-tool call limit using the guard (handles exemptions for math/discovery) + # per_tool_cap=0 means "disabled/unlimited" (see RuntimeLimits presets) + per_tool_result = tool_state.check_per_tool_limit(actual_tool_for_checks) + if tool_state.limits.per_tool_cap > 0 and per_tool_result.blocked: + log.warning(f"Tool {actual_tool_for_checks} blocked by per-tool limit") + output.warning( + f"⚠ Tool {actual_tool_for_checks} - {per_tool_result.reason}" + ) + self._add_tool_result_to_history( + llm_tool_name, + call_id, + per_tool_result.reason or "Per-tool limit reached", + ) + continue + + # Resolve $vN references in arguments (substitute actual values) + resolved_arguments = tool_state.resolve_references(arguments) + + # Store metadata for callbacks + self._call_metadata[call_id] = { + "llm_tool_name": llm_tool_name, + "execution_tool_name": execution_tool_name, + "display_name": display_name, + "arguments": resolved_arguments, # Use resolved arguments + "raw_arguments": raw_arguments, + } + + # Create CTP ToolCall with resolved arguments + ctp_calls.append( + CTPToolCall( + id=call_id, + tool=execution_tool_name, + arguments=resolved_arguments, + ) ) - ) - if self._cancelled or not ctp_calls: - await self._finish_tool_calls() - return + if self._cancelled or not ctp_calls: + return - if self.tool_manager is None: - raise RuntimeError("No tool manager available for tool execution") + if self.tool_manager is None: + raise RuntimeError("No tool manager available for tool execution") - # Execute tools in parallel using ToolManager's streaming API - try: + # Execute tools in parallel using ToolManager's streaming API async for result in self.tool_manager.stream_execute_tools( calls=ctp_calls, on_tool_start=self._on_tool_start, @@ -356,10 +359,14 @@ async def process_tool_calls( await self._on_tool_result(result) if self._cancelled: break # type: ignore[unreachable] + except asyncio.CancelledError: pass - - await self._finish_tool_calls() + finally: + # SAFETY NET: Ensure every tool_call_id has a matching result. + # This prevents OpenAI 400 errors from orphaned tool_call_ids. + self._ensure_all_tool_results(tool_calls) + await self._finish_tool_calls() def cancel_running_tasks(self) -> None: """Cancel running tool execution.""" @@ -520,6 +527,27 @@ def _track_transport_failures(self, success: bool, error: str | None) -> None: else: self._consecutive_transport_failures = 0 + def _ensure_all_tool_results(self, tool_calls: list[Any]) -> None: + """Ensure every tool_call_id in the assistant message has a matching result. + + This is a safety net that prevents OpenAI 400 errors caused by orphaned + tool_call_ids. If any tool_call_id is missing a result (due to guard + exceptions, silent failures, or interrupted execution), a placeholder + error result is added. 
+ """ + for idx, call in enumerate(tool_calls): + llm_tool_name, _, call_id = self._extract_tool_call_info(call, idx) + if call_id not in self._result_ids_added: + log.warning( + f"Missing tool result for {llm_tool_name} ({call_id}), " + "adding error placeholder" + ) + self._add_tool_result_to_history( + llm_tool_name, + call_id, + "Tool execution was interrupted or failed to complete.", + ) + async def _finish_tool_calls(self) -> None: """Signal UI that all tool calls are complete.""" if hasattr(self.ui_manager, "finish_tool_calls") and callable( @@ -781,6 +809,7 @@ def _add_tool_result_to_history( tool_call_id=call_id, ) self.context.inject_tool_message(tool_msg) + self._result_ids_added.add(call_id) log.debug(f"Added tool result to conversation history: {llm_tool_name}") except Exception as e: log.error(f"Error updating conversation history: {e}") diff --git a/tests/chat/test_chat_context.py b/tests/chat/test_chat_context.py index 22ad9378..75756d57 100644 --- a/tests/chat/test_chat_context.py +++ b/tests/chat/test_chat_context.py @@ -111,7 +111,7 @@ def dummy_tool_manager(): def chat_context(dummy_tool_manager, monkeypatch): # Use deterministic system prompt monkeypatch.setattr( - "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools: "SYS_PROMPT" + "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools=None, **kw: "SYS_PROMPT" ) # Mock ModelManager to avoid model discovery issues @@ -339,7 +339,7 @@ async def test_str(chat_context): async def test_context_manager(dummy_tool_manager, monkeypatch): """Test async context manager.""" monkeypatch.setattr( - "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools: "SYS_PROMPT" + "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools=None, **kw: "SYS_PROMPT" ) from unittest.mock import Mock @@ -403,7 +403,7 @@ async def test_refresh_after_model_change(chat_context): async def test_create_with_provider_only(dummy_tool_manager, monkeypatch): """Test ChatContext.create with provider only.""" monkeypatch.setattr( - "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools: "SYS_PROMPT" + "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools=None, **kw: "SYS_PROMPT" ) from unittest.mock import Mock @@ -430,7 +430,7 @@ async def test_create_with_provider_only(dummy_tool_manager, monkeypatch): async def test_create_with_model_only(dummy_tool_manager, monkeypatch): """Test ChatContext.create with model only.""" monkeypatch.setattr( - "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools: "SYS_PROMPT" + "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools=None, **kw: "SYS_PROMPT" ) from unittest.mock import Mock @@ -455,7 +455,7 @@ async def test_create_with_model_only(dummy_tool_manager, monkeypatch): async def test_create_with_provider_and_api_settings(dummy_tool_manager, monkeypatch): """Test ChatContext.create with provider and API settings.""" monkeypatch.setattr( - "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools: "SYS_PROMPT" + "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools=None, **kw: "SYS_PROMPT" ) from unittest.mock import Mock @@ -485,7 +485,7 @@ async def test_create_with_provider_model_and_api_settings( ): """Test ChatContext.create with all settings.""" monkeypatch.setattr( - "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools: "SYS_PROMPT" + "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools=None, **kw: "SYS_PROMPT" ) from unittest.mock import Mock @@ -517,7 +517,7 @@ async def 
test_initialize_failure(dummy_tool_manager, monkeypatch): # Patch generate_system_prompt to avoid issues monkeypatch.setattr( - "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools: "SYS_PROMPT" + "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools=None, **kw: "SYS_PROMPT" ) mock_manager = Mock(spec=ModelManager) @@ -544,7 +544,7 @@ async def raise_error(): async def test_regenerate_system_prompt_insert(dummy_tool_manager, monkeypatch): """Test regenerate_system_prompt when no system message exists.""" monkeypatch.setattr( - "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools: "SYS_PROMPT" + "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools=None, **kw: "SYS_PROMPT" ) from unittest.mock import Mock @@ -574,7 +574,7 @@ async def test_regenerate_system_prompt_insert(dummy_tool_manager, monkeypatch): async def test_context_manager_failure(dummy_tool_manager, monkeypatch): """Test async context manager handles initialization failure.""" monkeypatch.setattr( - "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools: "SYS_PROMPT" + "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools=None, **kw: "SYS_PROMPT" ) from unittest.mock import Mock @@ -604,7 +604,7 @@ async def fail_init(): async def test_adapt_tools_without_get_adapted_tools(dummy_tool_manager, monkeypatch): """Test _adapt_tools_for_provider fallback when get_adapted_tools_for_llm not available.""" monkeypatch.setattr( - "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools: "SYS_PROMPT" + "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools=None, **kw: "SYS_PROMPT" ) from unittest.mock import Mock @@ -650,7 +650,7 @@ async def get_tools_for_llm(self): async def test_adapt_tools_exception_fallback(dummy_tool_manager, monkeypatch): """Test _adapt_tools_for_provider handles exceptions.""" monkeypatch.setattr( - "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools: "SYS_PROMPT" + "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools=None, **kw: "SYS_PROMPT" ) from unittest.mock import Mock @@ -699,7 +699,7 @@ async def get_tools_for_llm(self): async def test_initialize_no_tools_warning(monkeypatch, capsys): """Test initialize prints warning when no tools available.""" monkeypatch.setattr( - "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools: "SYS_PROMPT" + "mcp_cli.chat.chat_context.generate_system_prompt", lambda tools=None, **kw: "SYS_PROMPT" ) from unittest.mock import Mock @@ -743,3 +743,49 @@ async def test_find_tool_by_name_partial_match(chat_context): tool = chat_context.find_tool_by_name("other.tool1") assert tool is not None assert tool.name == "tool1" + + +# --------------------------------------------------------------------------- +# Server/tool grouping tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_build_server_tool_groups(chat_context): + """_build_server_tool_groups returns correct grouping.""" + await chat_context.initialize() + groups = chat_context._build_server_tool_groups() + + assert len(groups) == 2 + + names = {g["name"] for g in groups} + assert names == {"srv1", "srv2"} + + for group in groups: + assert "name" in group + assert "description" in group + assert "tools" in group + assert len(group["tools"]) >= 1 + + +@pytest.mark.asyncio +async def test_build_server_tool_groups_empty(dummy_tool_manager, monkeypatch): + """_build_server_tool_groups returns empty list when no server_info.""" + 
monkeypatch.setattr( + "mcp_cli.chat.chat_context.generate_system_prompt", + lambda tools=None, **kw: "SYS_PROMPT", + ) + + from unittest.mock import Mock + from mcp_cli.model_management import ModelManager + + mock_manager = Mock(spec=ModelManager) + mock_manager.get_client.return_value = None + mock_manager.get_active_provider.return_value = "mock" + mock_manager.get_active_model.return_value = "mock-model" + + ctx = ChatContext.create( + tool_manager=dummy_tool_manager, model_manager=mock_manager + ) + # Don't initialize — server_info is empty + assert ctx._build_server_tool_groups() == [] diff --git a/tests/chat/test_conversation.py b/tests/chat/test_conversation.py index 5417128e..1426ee1c 100644 --- a/tests/chat/test_conversation.py +++ b/tests/chat/test_conversation.py @@ -1544,3 +1544,189 @@ async def test_extracts_bindings_with_values(self): # Should have called extract_bindings_from_text mock_tool_state.extract_bindings_from_text.assert_called_once() + + +class TestValidateToolMessages: + """Tests for _validate_tool_messages defense-in-depth validation.""" + + def test_valid_messages_unchanged(self): + """Messages with matching tool results are not modified.""" + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + {"id": "call_1", "type": "function", "function": {"name": "echo", "arguments": "{}"}}, + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "OK"}, + {"role": "assistant", "content": "Done."}, + ] + + result = ConversationProcessor._validate_tool_messages(messages) + assert result == messages + + def test_orphaned_tool_call_gets_placeholder(self): + """An assistant message with a tool_call_id missing a result gets a placeholder.""" + messages = [ + {"role": "user", "content": "Hello"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + {"id": "call_missing", "type": "function", "function": {"name": "fetch", "arguments": "{}"}}, + ], + }, + # No tool result for call_missing! 
+ {"role": "assistant", "content": "Something else."}, + ] + + result = ConversationProcessor._validate_tool_messages(messages) + + # Should have 4 messages now: user, assistant+tool_calls, tool placeholder, assistant + assert len(result) == 4 + assert result[2]["role"] == "tool" + assert result[2]["tool_call_id"] == "call_missing" + assert "did not complete" in result[2]["content"] + + def test_multiple_tool_calls_partial_results(self): + """When an assistant message has multiple tool_calls and only some have results.""" + messages = [ + {"role": "user", "content": "Do two things"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + {"id": "call_a", "type": "function", "function": {"name": "tool_a", "arguments": "{}"}}, + {"id": "call_b", "type": "function", "function": {"name": "tool_b", "arguments": "{}"}}, + ], + }, + {"role": "tool", "tool_call_id": "call_a", "content": "Result A"}, + # call_b result is missing + ] + + result = ConversationProcessor._validate_tool_messages(messages) + + # Should have 4 messages: user, assistant+tool_calls, tool result A, placeholder for B + assert len(result) == 4 + tool_results = [m for m in result if m.get("role") == "tool"] + assert len(tool_results) == 2 + tool_call_ids = {m["tool_call_id"] for m in tool_results} + assert tool_call_ids == {"call_a", "call_b"} + + def test_no_tool_calls_unchanged(self): + """Messages without tool_calls pass through unchanged.""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + result = ConversationProcessor._validate_tool_messages(messages) + assert result == messages + + def test_empty_messages(self): + """Empty message list returns empty.""" + assert ConversationProcessor._validate_tool_messages([]) == [] + + def test_multiple_sequential_tool_rounds(self): + """Multiple assistant→tool rounds are all validated correctly.""" + messages = [ + {"role": "user", "content": "Do things"}, + # Round 1: valid + { + "role": "assistant", + "content": None, + "tool_calls": [ + {"id": "call_r1", "type": "function", "function": {"name": "a", "arguments": "{}"}}, + ], + }, + {"role": "tool", "tool_call_id": "call_r1", "content": "Done A"}, + # Round 2: orphaned + { + "role": "assistant", + "content": None, + "tool_calls": [ + {"id": "call_r2", "type": "function", "function": {"name": "b", "arguments": "{}"}}, + ], + }, + # Missing tool result for call_r2! 
+ {"role": "assistant", "content": "Final answer."}, + ] + + result = ConversationProcessor._validate_tool_messages(messages) + + # Should have 6 messages: user, asst+tc, tool, asst+tc, PLACEHOLDER, asst + assert len(result) == 6 + assert result[4]["role"] == "tool" + assert result[4]["tool_call_id"] == "call_r2" + assert "did not complete" in result[4]["content"] + # Round 1 should be untouched + assert result[2]["tool_call_id"] == "call_r1" + assert result[2]["content"] == "Done A" + + def test_all_results_missing(self): + """When ALL tool results are missing from a multi-call assistant message.""" + messages = [ + {"role": "user", "content": "Run both"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + {"id": "call_x", "type": "function", "function": {"name": "x", "arguments": "{}"}}, + {"id": "call_y", "type": "function", "function": {"name": "y", "arguments": "{}"}}, + ], + }, + # No tool results at all — next message is user + {"role": "user", "content": "What happened?"}, + ] + + result = ConversationProcessor._validate_tool_messages(messages) + + # Should have 5 messages: user, asst+tc, placeholder_x, placeholder_y, user + assert len(result) == 5 + placeholders = [m for m in result if m.get("role") == "tool"] + assert len(placeholders) == 2 + placeholder_ids = {m["tool_call_id"] for m in placeholders} + assert placeholder_ids == {"call_x", "call_y"} + + def test_empty_tool_calls_list_is_noop(self): + """Assistant message with tool_calls=[] should not trigger repair.""" + messages = [ + {"role": "user", "content": "Hello"}, + { + "role": "assistant", + "content": "Using tools...", + "tool_calls": [], + }, + {"role": "assistant", "content": "Done."}, + ] + + result = ConversationProcessor._validate_tool_messages(messages) + assert result == messages + + def test_tool_results_not_immediately_following(self): + """Tool results separated from assistant message by another message type.""" + messages = [ + {"role": "user", "content": "Go"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + {"id": "call_gap", "type": "function", "function": {"name": "t", "arguments": "{}"}}, + ], + }, + # A user message appears before the tool result + {"role": "user", "content": "Hurry up"}, + {"role": "tool", "tool_call_id": "call_gap", "content": "Late result"}, + ] + + result = ConversationProcessor._validate_tool_messages(messages) + + # The tool result is not immediately following, so scanner won't find it + # A placeholder should be inserted right after the assistant message + assert len(result) == 5 + assert result[2]["role"] == "tool" + assert result[2]["tool_call_id"] == "call_gap" + assert "did not complete" in result[2]["content"] diff --git a/tests/chat/test_system_prompt.py b/tests/chat/test_system_prompt.py index f1937859..4cbbb56d 100644 --- a/tests/chat/test_system_prompt.py +++ b/tests/chat/test_system_prompt.py @@ -100,6 +100,95 @@ def test_dynamic_prompt_with_none_tools(self, monkeypatch): assert "0 tools" in result +class TestServerToolGroups: + """Tests for server/tool categorization in the system prompt.""" + + def test_no_server_groups_no_section(self): + """Without server_tool_groups, no server section appears.""" + from mcp_cli.chat.system_prompt import generate_system_prompt + + result = generate_system_prompt(tools=[{"name": "t"}], server_tool_groups=None) + assert "CONNECTED SERVERS" not in result + + def test_empty_server_groups_no_section(self): + """Empty server_tool_groups list should not add a section.""" + from mcp_cli.chat.system_prompt 
import generate_system_prompt + + result = generate_system_prompt(tools=[{"name": "t"}], server_tool_groups=[]) + assert "CONNECTED SERVERS" not in result + + def test_server_groups_appear_in_prompt(self): + """Server groups should be listed in the system prompt.""" + from mcp_cli.chat.system_prompt import generate_system_prompt + + groups = [ + {"name": "stac", "description": "stac MCP server", "tools": ["stac_search", "stac_describe"]}, + {"name": "dem", "description": "dem MCP server", "tools": ["dem_fetch"]}, + {"name": "time", "description": "time MCP server", "tools": ["get_current_time"]}, + ] + result = generate_system_prompt(tools=[{"name": "t"}] * 4, server_tool_groups=groups) + assert "CONNECTED SERVERS & AVAILABLE TOOLS" in result + assert "**stac**" in result + assert "**dem**" in result + assert "**time**" in result + assert "stac_search" in result + assert "dem_fetch" in result + assert "get_current_time" in result + + def test_server_groups_contain_all_tool_names(self): + """Every tool name from the groups should appear in the prompt.""" + from mcp_cli.chat.system_prompt import generate_system_prompt + + groups = [ + {"name": "math", "description": "math server", "tools": ["add", "subtract", "multiply"]}, + ] + result = generate_system_prompt(tools=[{"name": "t"}], server_tool_groups=groups) + assert "add, subtract, multiply" in result + + def test_dynamic_mode_ignores_server_groups(self, monkeypatch): + """Dynamic mode should not include server groups.""" + monkeypatch.setenv("MCP_CLI_DYNAMIC_TOOLS", "1") + from mcp_cli.chat.system_prompt import generate_system_prompt + + groups = [ + {"name": "time", "description": "time server", "tools": ["get_time"]}, + ] + result = generate_system_prompt(tools=[], server_tool_groups=groups) + assert "CONNECTED SERVERS" not in result + assert "TOOL DISCOVERY SYSTEM" in result + + def test_considers_all_servers_guidance(self): + """The prompt should tell the model to consider ALL servers.""" + from mcp_cli.chat.system_prompt import generate_system_prompt + + groups = [ + {"name": "a", "description": "a server", "tools": ["tool_a"]}, + ] + result = generate_system_prompt(tools=[{"name": "t"}], server_tool_groups=groups) + assert "ALL relevant servers" in result + + +class TestBuildServerSection: + """Direct tests for _build_server_section.""" + + def test_none_returns_empty(self): + from mcp_cli.chat.system_prompt import _build_server_section + + assert _build_server_section(None) == "" + + def test_empty_list_returns_empty(self): + from mcp_cli.chat.system_prompt import _build_server_section + + assert _build_server_section([]) == "" + + def test_single_server(self): + from mcp_cli.chat.system_prompt import _build_server_section + + groups = [{"name": "time", "description": "time MCP server", "tools": ["get_time", "convert_tz"]}] + result = _build_server_section(groups) + assert "**time** (time MCP server): get_time, convert_tz" in result + + class TestPrivateFunctionDirectly: """Direct calls to _generate_dynamic_tools_prompt for completeness.""" diff --git a/tests/chat/test_tool_processor.py b/tests/chat/test_tool_processor.py index ed4004f0..95949c6e 100644 --- a/tests/chat/test_tool_processor.py +++ b/tests/chat/test_tool_processor.py @@ -313,10 +313,110 @@ async def test_process_tool_calls_no_tool_manager(): } # Pass as a list to process_tool_calls - should raise RuntimeError + # The finally block still runs, so missing results are filled in. 
with pytest.raises(RuntimeError, match="No tool manager available"): await processor.process_tool_calls([tool_call]) +@pytest.mark.asyncio +async def test_ensure_all_tool_results_fills_missing(): + """Test that _ensure_all_tool_results adds placeholder results for missing tool_call_ids.""" + result_dict = {"isError": False, "content": "OK"} + tool_manager = DummyToolManager(return_result=result_dict) + context = DummyContext(tool_manager=tool_manager) + ui_manager = DummyUIManager() + processor = ToolProcessor(context, ui_manager) + + # Simulate: assistant message was added but no tool result was added + processor._result_ids_added = set() + tool_calls = [ + ToolCall( + id="call_orphan_1", + type="function", + function=FunctionCall(name="tool_a", arguments='{"x": 1}'), + ), + ToolCall( + id="call_orphan_2", + type="function", + function=FunctionCall(name="tool_b", arguments='{"y": 2}'), + ), + ] + + # Only tool_a got a result + processor._result_ids_added.add("call_orphan_1") + + processor._ensure_all_tool_results(tool_calls) + + # tool_b should now have a placeholder result in the conversation history + tool_results = [ + msg for msg in context.conversation_history + if msg.role.value == "tool" + ] + assert len(tool_results) == 1 + assert tool_results[0].tool_call_id == "call_orphan_2" + assert "interrupted or failed" in tool_results[0].content + + +@pytest.mark.asyncio +async def test_ensure_all_tool_results_noop_when_all_present(): + """Test that _ensure_all_tool_results does nothing when all results are present.""" + result_dict = {"isError": False, "content": "OK"} + tool_manager = DummyToolManager(return_result=result_dict) + context = DummyContext(tool_manager=tool_manager) + ui_manager = DummyUIManager() + processor = ToolProcessor(context, ui_manager) + + processor._result_ids_added = {"call_1", "call_2"} + tool_calls = [ + ToolCall( + id="call_1", + type="function", + function=FunctionCall(name="tool_a", arguments='{}'), + ), + ToolCall( + id="call_2", + type="function", + function=FunctionCall(name="tool_b", arguments='{}'), + ), + ] + + processor._ensure_all_tool_results(tool_calls) + + # No new messages should have been added + assert len(context.conversation_history) == 0 + + +@pytest.mark.asyncio +async def test_process_tool_calls_finally_adds_missing_results(): + """Test that the finally block adds results for tools that never executed. + + When a tool_manager raises RuntimeError, the finally block should still + ensure all tool_call_ids have results, preventing OpenAI 400 errors. + """ + context = DummyContext(stream_manager=None, tool_manager=None) + ui_manager = DummyUIManager() + processor = ToolProcessor(context, ui_manager) + + tool_call = ToolCall( + id="call_no_manager", + type="function", + function=FunctionCall(name="some_tool", arguments='{"a": "b"}'), + ) + + with pytest.raises(RuntimeError): + await processor.process_tool_calls([tool_call]) + + # The finally block should have added a placeholder result + tool_results = [ + msg for msg in context.conversation_history + if msg.role.value == "tool" + ] + assert len(tool_results) >= 1 + assert any( + msg.tool_call_id == "call_no_manager" for msg in tool_results + ) + + @pytest.mark.asyncio async def test_process_tool_calls_exception_in_call(): # Test that an exception raised during execute_tool is caught and an error is recorded. 
@@ -341,3 +441,146 @@ async def test_process_tool_calls_exception_in_call(): assert len(error_entries) >= 1 # The error should contain the exception message assert any("Simulated execute_tool exception" in e.content for e in error_entries) + + +# --------------------------- +# Tests for orphaned tool_call_id safety net +# --------------------------- + + +class DenyConfirmUIManager(DummyUIManager): + """UI manager that denies tool confirmation.""" + + def do_confirm_tool_execution(self, tool_name, arguments): + return False + + +class FailingInjectContext(DummyContext): + """Context whose inject_tool_message raises after the first N calls.""" + + def __init__(self, tool_manager=None, fail_after=1): + super().__init__(tool_manager=tool_manager) + self._inject_count = 0 + self._fail_after = fail_after + + def inject_tool_message(self, message): + self._inject_count += 1 + if self._inject_count > self._fail_after: + raise RuntimeError("Simulated inject failure") + super().inject_tool_message(message) + + +@pytest.mark.asyncio +async def test_cancelled_tool_still_gets_result_for_remaining(): + """When user cancels confirmation on the 2nd tool, the 2nd tool still gets a result.""" + result_dict = {"isError": False, "content": "OK"} + tool_manager = DummyToolManager(return_result=result_dict) + context = DummyContext(tool_manager=tool_manager) + + call_count = [0] + + class SelectiveDenyUI(DummyUIManager): + """Denies the second tool call.""" + def do_confirm_tool_execution(self, tool_name, arguments): + call_count[0] += 1 + return call_count[0] <= 1 # Allow first, deny second + + ui_manager = SelectiveDenyUI() + processor = ToolProcessor(context, ui_manager) + + tool_calls = [ + ToolCall(id="call_ok", type="function", + function=FunctionCall(name="tool_a", arguments='{}')), + ToolCall(id="call_denied", type="function", + function=FunctionCall(name="tool_b", arguments='{}')), + ] + + await processor.process_tool_calls(tool_calls) + + # Both tool_call_ids must have results (one executed, one cancelled/placeholder) + tool_results = [ + msg for msg in context.conversation_history + if msg.role.value == "tool" + ] + result_ids = {msg.tool_call_id for msg in tool_results} + assert "call_ok" in result_ids or "call_denied" in result_ids + # The key assertion: no orphaned tool_call_ids + assistant_msgs = [ + msg for msg in context.conversation_history + if msg.role.value == "assistant" and msg.tool_calls + ] + for amsg in assistant_msgs: + for tc in amsg.tool_calls: + tc_id = tc.id if hasattr(tc, "id") else tc.get("id") + assert tc_id in result_ids, f"Orphaned tool_call_id: {tc_id}" + + +@pytest.mark.asyncio +async def test_inject_failure_is_caught_by_finally(): + """When inject_tool_message fails, the finally block fills missing results.""" + result_dict = {"isError": False, "content": "OK"} + tool_manager = DummyToolManager(return_result=result_dict) + # fail_after=1 means the assistant message succeeds, but the tool result inject fails + context = FailingInjectContext(tool_manager=tool_manager, fail_after=1) + ui_manager = DummyUIManager() + processor = ToolProcessor(context, ui_manager) + + tool_call = ToolCall( + id="call_inject_fail", type="function", + function=FunctionCall(name="some_tool", arguments='{}'), + ) + + # The tool result inject fails, then execution runs, then finally block fires. + # Since inject keeps failing, the finally block's attempt also fails. + # But the key point is process_tool_calls doesn't crash. 
+ await processor.process_tool_calls([tool_call]) + + # The assistant message should be in history (first inject succeeded) + assert len(context.conversation_history) >= 1 + assert context.conversation_history[0].role.value == "assistant" + + +@pytest.mark.asyncio +async def test_result_ids_reset_between_calls(): + """_result_ids_added is reset on each call to process_tool_calls.""" + result_dict = {"isError": False, "content": "OK"} + tool_manager = DummyToolManager(return_result=result_dict) + context = DummyContext(tool_manager=tool_manager) + ui_manager = DummyUIManager() + processor = ToolProcessor(context, ui_manager) + + # First call + tc1 = ToolCall(id="call_first", type="function", + function=FunctionCall(name="tool_a", arguments='{}')) + await processor.process_tool_calls([tc1]) + assert "call_first" in processor._result_ids_added + + # Second call should start fresh + tc2 = ToolCall(id="call_second", type="function", + function=FunctionCall(name="tool_b", arguments='{}')) + await processor.process_tool_calls([tc2]) + assert "call_second" in processor._result_ids_added + assert "call_first" not in processor._result_ids_added + + +@pytest.mark.asyncio +async def test_successful_batch_tracks_all_ids(): + """All tool_call_ids in a successful batch are tracked in _result_ids_added.""" + result_dict = {"isError": False, "content": "OK"} + tool_manager = DummyToolManager(return_result=result_dict) + context = DummyContext(tool_manager=tool_manager) + ui_manager = DummyUIManager() + processor = ToolProcessor(context, ui_manager) + + tool_calls = [ + ToolCall(id="batch_1", type="function", + function=FunctionCall(name="tool_a", arguments='{}')), + ToolCall(id="batch_2", type="function", + function=FunctionCall(name="tool_b", arguments='{}')), + ToolCall(id="batch_3", type="function", + function=FunctionCall(name="tool_c", arguments='{}')), + ] + + await processor.process_tool_calls(tool_calls) + + assert processor._result_ids_added == {"batch_1", "batch_2", "batch_3"}