From 99ae634a6891319ce8a02cb316398d4e741c5ca6 Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Tue, 23 Dec 2025 10:00:55 +0000 Subject: [PATCH 01/15] typesafe --- src/fast_agent/acp/protocols.py | 53 ++++ src/fast_agent/acp/slash_commands.py | 195 +++++++------ src/fast_agent/agents/llm_decorator.py | 19 +- src/fast_agent/agents/mcp_agent.py | 153 +++------- .../agents/workflow/iterative_planner.py | 4 +- .../agents/workflow/router_agent.py | 3 +- src/fast_agent/core/fastagent.py | 21 +- src/fast_agent/core/instruction.py | 71 ++++- src/fast_agent/core/instruction_refresh.py | 261 +++++++++++++++++- src/fast_agent/interfaces.py | 9 +- src/fast_agent/mcp/types.py | 26 +- src/fast_agent/tools/elicitation.py | 40 ++- src/fast_agent/tools/shell_runtime.py | 8 +- src/fast_agent/ui/command_payloads.py | 104 +++++++ src/fast_agent/ui/enhanced_prompt.py | 244 ++++++++++------ src/fast_agent/ui/interactive_prompt.py | 113 ++++---- src/fast_agent/ui/streaming.py | 4 +- src/fast_agent/ui/streaming_buffer.py | 12 +- .../core/test_instruction_refresh.py | 195 +++++++++++-- typesafe.md | 119 ++++++++ 20 files changed, 1189 insertions(+), 465 deletions(-) create mode 100644 src/fast_agent/acp/protocols.py create mode 100644 src/fast_agent/ui/command_payloads.py create mode 100644 typesafe.md diff --git a/src/fast_agent/acp/protocols.py b/src/fast_agent/acp/protocols.py new file mode 100644 index 000000000..06651de04 --- /dev/null +++ b/src/fast_agent/acp/protocols.py @@ -0,0 +1,53 @@ +""" +ACP capability Protocols for type-safe isinstance checks. + +These Protocols define optional capabilities that agents may implement. +Use isinstance() checks instead of hasattr() to verify capability support. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from fast_agent.acp.filesystem_runtime import ACPFilesystemRuntime + from fast_agent.acp.terminal_runtime import ACPTerminalRuntime + from fast_agent.workflow_telemetry import PlanTelemetryProvider, WorkflowTelemetryProvider + + +@runtime_checkable +class ShellRuntimeCapable(Protocol): + """Agent that supports external shell runtime injection.""" + + @property + def _shell_runtime_enabled(self) -> bool: ... + + def set_external_runtime(self, runtime: "ACPTerminalRuntime") -> None: ... + + +@runtime_checkable +class FilesystemRuntimeCapable(Protocol): + """Agent that supports external filesystem runtime injection.""" + + def set_filesystem_runtime(self, runtime: "ACPFilesystemRuntime") -> None: ... + + +@runtime_checkable +class InstructionContextCapable(Protocol): + """Agent that supports dynamic instruction context updates.""" + + def set_instruction_context(self, context: dict[str, str]) -> None: ... 
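These capability Protocols are runtime_checkable, so isinstance() verifies the declared surface instead of hasattr() probing. A minimal sketch of the intended call pattern; the inject_capabilities helper and the context values below are illustrative, not part of the patch:

    from fast_agent.acp.protocols import InstructionContextCapable, ShellRuntimeCapable

    def inject_capabilities(agent: object, runtime: "ACPTerminalRuntime") -> None:
        # isinstance() against a runtime_checkable Protocol checks that the
        # required members exist, replacing untyped hasattr() checks.
        if isinstance(agent, ShellRuntimeCapable) and agent._shell_runtime_enabled:
            agent.set_external_runtime(runtime)
        if isinstance(agent, InstructionContextCapable):
            agent.set_instruction_context({"workspaceRoot": "/workspace"})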
+ + +@runtime_checkable +class WorkflowTelemetryCapable(Protocol): + """Agent that supports workflow telemetry.""" + + workflow_telemetry: "WorkflowTelemetryProvider | None" + + +@runtime_checkable +class PlanTelemetryCapable(Protocol): + """Agent that supports plan telemetry.""" + + plan_telemetry: "PlanTelemetryProvider | None" diff --git a/src/fast_agent/acp/slash_commands.py b/src/fast_agent/acp/slash_commands.py index 38862c151..0c0b368ce 100644 --- a/src/fast_agent/acp/slash_commands.py +++ b/src/fast_agent/acp/slash_commands.py @@ -38,6 +38,7 @@ from fast_agent.mcp.helpers.content_helpers import get_text from fast_agent.mcp.prompts.prompt_load import load_history_into_agent from fast_agent.skills.manager import ( + MarketplaceSkill, candidate_marketplace_urls, fetch_marketplace_skills, fetch_marketplace_skills_with_source, @@ -52,7 +53,7 @@ select_manifest_by_name_or_index, select_skill_by_name_or_index, ) -from fast_agent.skills.registry import format_skills_for_prompt +from fast_agent.skills.registry import SkillManifest, format_skills_for_prompt from fast_agent.types.conversation_summary import ConversationSummary from fast_agent.utils.time import format_duration @@ -82,8 +83,23 @@ def instruction(self) -> str | None: ... @runtime_checkable -class InstructionRefreshAgent(InstructionAwareAgent, Protocol): - async def rebuild_instruction_templates(self) -> None: ... +class ACPCommandAllowlistProvider(Protocol): + @property + def acp_session_commands_allowlist(self) -> set[str] | None: ... + + +@runtime_checkable +class ParallelAgentProtocol(Protocol): + @property + def fan_out_agents(self) -> list[AgentProtocol] | None: ... + + @property + def fan_in_agent(self) -> AgentProtocol | None: ... + + +@runtime_checkable +class HfDisplayInfoProvider(Protocol): + def get_hf_display_info(self) -> dict[str, Any]: ... class SlashCommandHandler: @@ -193,20 +209,16 @@ def _get_allowed_session_commands(self) -> dict[str, AvailableCommand]: Return session-level commands filtered by the current agent's policy. By default, all session commands are available. ACP-aware agents can restrict - session commands (e.g. Setup/wizard flows) by defining either: - - `acp_session_commands_allowlist: set[str] | None` attribute, or - - `acp_session_commands_allowlist() -> set[str] | None` method + session commands (e.g. Setup/wizard flows) by defining a + `acp_session_commands_allowlist: set[str] | None` attribute. 
""" agent = self._get_current_agent() if not isinstance(agent, ACPAwareProtocol): return self._session_commands - allowlist = getattr(agent, "acp_session_commands_allowlist", None) - if callable(allowlist): - try: - allowlist = allowlist() - except Exception: - allowlist = None + allowlist = None + if isinstance(agent, ACPCommandAllowlistProvider): + allowlist = agent.acp_session_commands_allowlist if allowlist is None: return self._session_commands @@ -358,9 +370,7 @@ async def _handle_status(self, arguments: str | None = None) -> str: agent = self._get_current_agent() # Check if this is a PARALLEL agent - is_parallel_agent = ( - agent and hasattr(agent, "agent_type") and agent.agent_type == AgentType.PARALLEL - ) + is_parallel_agent = agent is not None and agent.agent_type == AgentType.PARALLEL # For non-parallel agents, extract standard model info model_name = "unknown" @@ -374,9 +384,7 @@ async def _handle_status(self, arguments: str | None = None) -> str: if model_info: model_name = model_info.name model_provider = str(model_info.provider.value) - model_provider_display = getattr( - model_info.provider, "display_name", model_provider - ) + model_provider_display = model_info.provider.display_name if model_info.context_window: context_window = f"{model_info.context_window} tokens" capability_parts = [] @@ -448,19 +456,20 @@ async def _handle_status(self, arguments: str | None = None) -> str: status_lines.append("") # Display fan-out agents - if hasattr(agent, "fan_out_agents") and agent.fan_out_agents: - status_lines.append(f"### Fan-Out Agents ({len(agent.fan_out_agents)})") - for idx, fan_out_agent in enumerate(agent.fan_out_agents, 1): - agent_name = getattr(fan_out_agent, "name", f"agent-{idx}") + fan_out_agents = ( + agent.fan_out_agents if isinstance(agent, ParallelAgentProtocol) else None + ) + if fan_out_agents: + status_lines.append(f"### Fan-Out Agents ({len(fan_out_agents)})") + for idx, fan_out_agent in enumerate(fan_out_agents, 1): + agent_name = fan_out_agent.name status_lines.append(f"**{idx}. 
{agent_name}**") # Get model info for this fan-out agent if fan_out_agent.llm: model_info = ModelInfo.from_llm(fan_out_agent.llm) if model_info: - provider_display = getattr( - model_info.provider, "display_name", str(model_info.provider.value) - ) + provider_display = model_info.provider.display_name status_lines.append(f" - Provider: {provider_display}") status_lines.append(f" - Model: {model_info.name}") if model_info.context_window: @@ -476,18 +485,16 @@ async def _handle_status(self, arguments: str | None = None) -> str: status_lines.append("") # Display fan-in agent - if hasattr(agent, "fan_in_agent") and agent.fan_in_agent: - fan_in_agent = agent.fan_in_agent - fan_in_name = getattr(fan_in_agent, "name", "aggregator") + fan_in_agent = agent.fan_in_agent if isinstance(agent, ParallelAgentProtocol) else None + if fan_in_agent: + fan_in_name = fan_in_agent.name status_lines.append(f"### Fan-In Agent: {fan_in_name}") # Get model info for fan-in agent if fan_in_agent.llm: model_info = ModelInfo.from_llm(fan_in_agent.llm) if model_info: - provider_display = getattr( - model_info.provider, "display_name", str(model_info.provider.value) - ) + provider_display = model_info.provider.display_name status_lines.append(f" - Provider: {provider_display}") status_lines.append(f" - Model: {model_info.name}") if model_info.context_window: @@ -509,10 +516,9 @@ async def _handle_status(self, arguments: str | None = None) -> str: provider_line = f"{model_provider_display} ({model_provider})" # For HuggingFace, add the routing provider info - if agent and agent.llm: - get_hf_info = getattr(agent.llm, "get_hf_display_info", None) - if callable(get_hf_info): - hf_info = get_hf_info() + if agent and agent.llm and isinstance(agent.llm, HfDisplayInfoProvider): + hf_info = agent.llm.get_hf_display_info() + if hf_info: hf_provider = hf_info.get("provider", "auto-routing") provider_line = f"{model_provider_display} ({model_provider}) / {hf_provider}" @@ -529,7 +535,7 @@ async def _handle_status(self, arguments: str | None = None) -> str: # Add conversation statistics status_lines.append( - f"## Conversation Statistics ({getattr(agent, 'name', self.current_agent_name) if agent else 'Unknown'})" + f"## Conversation Statistics ({agent.name if agent else 'Unknown'})" ) uptime_seconds = max(time.time() - self._created_at, 0.0) @@ -552,12 +558,6 @@ async def _handle_status_system(self) -> str: if error: return error - if isinstance(agent, InstructionRefreshAgent): - try: - await agent.rebuild_instruction_templates() - except Exception: - pass - system_prompt = agent.instruction if isinstance(agent, InstructionAwareAgent) else None if not system_prompt: return "\n".join( @@ -729,10 +729,7 @@ async def _handle_skills_registry(self, argument: str) -> str: # Get configured registries from settings settings = get_settings() - skills_settings = getattr(settings, "skills", None) - configured_urls: list[str] = [] - if skills_settings: - configured_urls = getattr(skills_settings, "marketplace_urls", None) or [] + configured_urls = settings.skills.marketplace_urls or [] if not argument: current = get_marketplace_url(settings) @@ -799,8 +796,7 @@ async def _handle_skills_registry(self, argument: str) -> str: ) # Update only the active registry, preserve the configured list - if skills_settings is not None: - skills_settings.marketplace_url = resolved_url + settings.skills.marketplace_url = resolved_url display_url = format_marketplace_display_url(resolved_url) if candidates: @@ -834,6 +830,7 @@ async def _handle_skills_add(self, 
argument: str) -> str:
         agent, error = self._get_current_agent_or_error("# skills add")
         if error:
             return error
+        assert agent is not None
 
         tool_call_id = self._build_tool_call_id()
         await self._send_skills_update(
@@ -956,6 +953,7 @@ async def _handle_skills_remove(self, argument: str) -> str:
         agent, error = self._get_current_agent_or_error("# skills remove")
         if error:
             return error
+        assert agent is not None
 
         await self._refresh_agent_skills(agent)
 
@@ -986,7 +984,7 @@ async def _refresh_agent_skills(self, agent: AgentProtocol) -> None:
             skill_registry=registry,
         )
 
-    def _format_local_skills(self, manifests: list[Any], manager_dir: Path) -> str:
+    def _format_local_skills(self, manifests: list[SkillManifest], manager_dir: Path) -> str:
         lines = ["# skills", "", f"Directory: `{manager_dir}`", ""]
         if not manifests:
             lines.append("No skills available in the manager directory.")
@@ -1001,12 +999,12 @@ def _format_local_skills(self, manifests: list[Any], manager_dir: Path) -> str:
             lines.append("Change skills registry with `/skills registry <url>`.\n")
         return "\n".join(lines)
 
-    def _format_local_list(self, manifests: list[Any]) -> list[str]:
+    def _format_local_list(self, manifests: list[SkillManifest]) -> list[str]:
         lines: list[str] = []
         for index, manifest in enumerate(manifests, 1):
-            name = getattr(manifest, "name", "")
-            description = getattr(manifest, "description", "")
-            path = Path(getattr(manifest, "path", Path()))
+            name = manifest.name
+            description = manifest.description
+            path = manifest.path
             source_path = path.parent if path.is_file() else path
             try:
                 display_path = source_path.relative_to(Path.cwd())
@@ -1020,12 +1018,12 @@ def _format_local_list(self, manifests: list[Any]) -> list[str]:
                 lines.append(f"   - source: `{display_path}`")
         return lines
 
-    def _format_marketplace_list(self, marketplace: list[Any]) -> list[str]:
+    def _format_marketplace_list(self, marketplace: list[MarketplaceSkill]) -> list[str]:
         lines: list[str] = []
-        current_bundle = None
+        current_bundle: str | None = None
         for index, entry in enumerate(marketplace, 1):
-            bundle_name = getattr(entry, "bundle_name", None)
-            bundle_description = getattr(entry, "bundle_description", None)
+            bundle_name = entry.bundle_name
+            bundle_description = entry.bundle_description
             if bundle_name and bundle_name != current_bundle:
                 current_bundle = bundle_name
                 if lines:
@@ -1043,16 +1041,16 @@ def _format_marketplace_list(self, marketplace: list[Any]) -> list[str]:
                 lines.append(f"   - source: [link]({entry.source_url})")
         return lines
 
-    def _format_repo_label(self, entry: Any) -> str | None:
-        repo_url = getattr(entry, "repo_url", None)
+    def _format_repo_label(self, entry: MarketplaceSkill) -> str | None:
+        repo_url = entry.repo_url
         if not repo_url:
             return None
-        repo_ref = getattr(entry, "repo_ref", None)
+        repo_ref = entry.repo_ref
         if repo_ref:
             return f"{repo_url}@{repo_ref}"
         return repo_url
 
-    def _get_marketplace_repo_hint(self, marketplace: list[Any]) -> str | None:
+    def _get_marketplace_repo_hint(self, marketplace: list[MarketplaceSkill]) -> str | None:
         if not marketplace:
             return None
         return self._format_repo_label(marketplace[0])
@@ -1070,7 +1068,9 @@ async def _send_skills_update(
         message: str | None = None,
         start: bool = False,
     ) -> None:
-        acp = getattr(agent, "acp", None)
+        if not isinstance(agent, ACPAwareProtocol):
+            return
+        acp = agent.acp
         if not acp:
             return
         try:
@@ -1166,6 +1166,7 @@ async def _handle_save(self, arguments: str | None = None) -> str:
         )
         if error:
             return error
+        assert agent is not None
 
         filename = arguments.strip() if 
arguments and arguments.strip() else None @@ -1200,6 +1201,7 @@ async def _handle_load(self, arguments: str | None = None) -> str: ) if error: return error + assert agent is not None filename = arguments.strip() if arguments and arguments.strip() else None @@ -1221,11 +1223,9 @@ async def _handle_load(self, arguments: str | None = None) -> str: "", f"File not found: `{filename}`", ] - ) + ) try: - if hasattr(agent, "rebuild_instruction_templates"): - await agent.rebuild_instruction_templates() load_history_into_agent(agent, file_path) except Exception as exc: return "\n".join( @@ -1237,9 +1237,7 @@ async def _handle_load(self, arguments: str | None = None) -> str: ] ) - message_count = len(agent.message_history) if hasattr(agent, "message_history") else 0 - if hasattr(agent, "rebuild_instruction_templates"): - await agent.rebuild_instruction_templates() + message_count = len(agent.message_history) return "\n".join( [ @@ -1267,19 +1265,12 @@ def _handle_clear_all(self) -> str: ) if error: return error + assert agent is not None try: - history = getattr(agent, "message_history", None) - original_count = len(history) if isinstance(history, list) else None - - cleared = False - clear_method = getattr(agent, "clear", None) - if callable(clear_method): - clear_method() - cleared = True - elif isinstance(history, list): - history.clear() - cleared = True + original_count = len(agent.message_history) + agent.clear() + cleared = True except Exception as exc: return "\n".join( [ @@ -1323,16 +1314,12 @@ def _handle_clear_last(self) -> str: ) if error: return error + assert agent is not None try: - removed = None - pop_method = getattr(agent, "pop_last_message", None) - if callable(pop_method): - removed = pop_method() - else: - history = getattr(agent, "message_history", None) - if isinstance(history, list) and history: - removed = history.pop() + removed = agent.pop_last_message() + if removed is None and agent.message_history: + removed = agent.message_history.pop() except Exception as exc: return "\n".join( [ @@ -1352,7 +1339,7 @@ def _handle_clear_last(self) -> str: ] ) - role = getattr(removed, "role", "message") + role = removed.role if removed else "message" return "\n".join( [ heading, @@ -1361,9 +1348,9 @@ def _handle_clear_last(self) -> str: ] ) - def _get_conversation_stats(self, agent) -> list[str]: + def _get_conversation_stats(self, agent: AgentProtocol | None) -> list[str]: """Get conversation statistics from the agent's message history.""" - if not agent or not hasattr(agent, "message_history"): + if not agent: return [ "- Turns: 0", "- Tool Calls: 0", @@ -1420,17 +1407,19 @@ def _get_conversation_stats(self, agent) -> list[str]: f"- Context Used: error ({e})", ] - def _get_error_handling_report(self, agent, max_entries: int = 3) -> list[str]: + def _get_error_handling_report( + self, agent: AgentProtocol | None, max_entries: int = 3 + ) -> list[str]: """Summarize error channel availability and recent entries.""" channel_label = f"Error Channel: {FAST_AGENT_ERROR_CHANNEL}" - if not agent or not hasattr(agent, "message_history"): + if not agent: return ["_No errors recorded_"] recent_entries: list[str] = [] - history = getattr(agent, "message_history", []) or [] + history = agent.message_history for message in reversed(history): - channels = getattr(message, "channels", None) or {} + channels = message.channels or {} channel_blocks = channels.get(FAST_AGENT_ERROR_CHANNEL) if not channel_blocks: continue @@ -1460,7 +1449,7 @@ def _get_error_handling_report(self, agent, max_entries: int 
= 3) -> list[str]: return ["_No errors recorded_"] - def _get_warning_report(self, agent, max_entries: int = 5) -> list[str]: + def _get_warning_report(self, agent: AgentProtocol | None, max_entries: int = 5) -> list[str]: warnings: list[str] = [] if isinstance(agent, WarningAwareAgent): warnings.extend(agent.warnings) @@ -1485,10 +1474,10 @@ def _get_warning_report(self, agent, max_entries: int = 5) -> list[str]: lines.append(f"- ... ({len(cleaned) - max_entries} more)") return lines - def _context_usage_line(self, summary: ConversationSummary, agent) -> str: + def _context_usage_line(self, summary: ConversationSummary, agent: AgentProtocol) -> str: """Generate a context usage line with token estimation and fallbacks.""" # Prefer usage accumulator when available (matches enhanced/interactive prompt display) - usage = getattr(agent, "usage_accumulator", None) + usage = agent.usage_accumulator if usage: window = usage.context_window_size tokens = usage.current_context_tokens @@ -1501,7 +1490,7 @@ def _context_usage_line(self, summary: ConversationSummary, agent) -> str: # Fallback to tokenizing the actual conversation text token_count, char_count = self._estimate_tokens(summary, agent) - model_info = ModelInfo.from_llm(agent.llm) if getattr(agent, "llm", None) else None + model_info = ModelInfo.from_llm(agent.llm) if agent.llm else None if model_info and model_info.context_window: percentage = ( (token_count / model_info.context_window) * 100 @@ -1514,11 +1503,13 @@ def _context_usage_line(self, summary: ConversationSummary, agent) -> str: token_text = f"~{token_count:,} tokens" if token_count else "~0 tokens" return f"- Context Used: {char_count:,} chars ({token_text} est.)" - def _estimate_tokens(self, summary: ConversationSummary, agent) -> tuple[int, int]: + def _estimate_tokens( + self, summary: ConversationSummary, agent: AgentProtocol + ) -> tuple[int, int]: """Estimate tokens and return (tokens, characters) for the conversation history.""" text_parts: list[str] = [] for message in summary.messages: - for content in getattr(message, "content", []) or []: + for content in message.content: text = get_text(content) if text: text_parts.append(text) @@ -1529,9 +1520,9 @@ def _estimate_tokens(self, summary: ConversationSummary, agent) -> tuple[int, in return 0, 0 model_name = None - llm = getattr(agent, "llm", None) + llm = agent.llm if llm: - model_name = getattr(llm, "model_name", None) + model_name = llm.model_name token_count = self._count_tokens_with_tiktoken(combined, model_name) return token_count, char_count diff --git a/src/fast_agent/agents/llm_decorator.py b/src/fast_agent/agents/llm_decorator.py index 4d4cad1b6..e0f9f2698 100644 --- a/src/fast_agent/agents/llm_decorator.py +++ b/src/fast_agent/agents/llm_decorator.py @@ -179,7 +179,7 @@ def __init__( self._context = context self._name = self.config.name self._tracer = trace.get_tracer(__name__) - self.instruction = self.config.instruction + self._instruction = self.config.instruction # Agent-owned conversation state (PromptMessageExtended only) self._message_history: list[PromptMessageExtended] = [] @@ -214,6 +214,17 @@ async def initialize(self) -> None: async def shutdown(self) -> None: self.initialized = False + @property + def instruction(self) -> str: + """Return the agent's instruction/system prompt.""" + return self._instruction + + def set_instruction(self, instruction: str) -> None: + """Set the agent's instruction/system prompt.""" + self._instruction = instruction + if self._default_request_params: + 
self._default_request_params.systemPrompt = instruction + @property def agent_type(self) -> AgentType: """ @@ -1001,9 +1012,9 @@ def append_history(self, messages: list[PromptMessageExtended] | None) -> None: def pop_last_message(self) -> PromptMessageExtended | None: """Remove and return the most recent message from the conversation history.""" - if self.llm: - return self.llm.pop_last_message() - return None + if not self._message_history: + return None + return self._message_history.pop() @property def usage_accumulator(self) -> UsageAccumulator | None: diff --git a/src/fast_agent/agents/mcp_agent.py b/src/fast_agent/agents/mcp_agent.py index e7dbff11a..3067483e2 100644 --- a/src/fast_agent/agents/mcp_agent.py +++ b/src/fast_agent/agents/mcp_agent.py @@ -38,18 +38,16 @@ from fast_agent.agents.tool_agent import ToolAgent from fast_agent.constants import HUMAN_INPUT_TOOL_NAME from fast_agent.core.exceptions import PromptExitError -from fast_agent.core.instruction import InstructionBuilder from fast_agent.core.logging.logger import get_logger from fast_agent.interfaces import FastAgentLLMProtocol from fast_agent.mcp.common import ( - create_namespaced_name, get_resource_name, get_server_name, is_namespaced_name, ) from fast_agent.mcp.mcp_aggregator import MCPAggregator, NamespacedTool, ServerStatus from fast_agent.skills import SkillManifest -from fast_agent.skills.registry import SkillRegistry, format_skills_for_prompt +from fast_agent.skills.registry import SkillRegistry from fast_agent.tools.elicitation import ( get_elicitation_tool, run_elicitation_form, @@ -106,7 +104,7 @@ def __init__( # Store the original template - resolved instruction set after build() self._instruction_template = self.config.instruction - self.instruction = self.config.instruction # Will be replaced by builder output + self._instruction = self.config.instruction # Will be replaced by builder output self.executor = context.executor if context else None self.logger = get_logger(f"{__name__}.{self._name}") manifests: list[SkillManifest] = list(getattr(self.config, "skill_manifests", []) or []) @@ -179,12 +177,8 @@ def __init__( if self._shell_runtime_enabled: self._shell_runtime.announce() - # Create instruction builder with dynamic resolvers - self._instruction_builder = InstructionBuilder(self._instruction_template or "") - self._instruction_builder.set_resolver( - "serverInstructions", self._resolve_server_instructions - ) - self._instruction_builder.set_resolver("agentSkills", self._resolve_agent_skills) + # Store instruction context for template resolution + self._instruction_context: dict[str, str] = {} # Allow external runtime injection (e.g., for ACP terminal support) self._external_runtime = None @@ -269,6 +263,26 @@ def aggregator(self) -> MCPAggregator: """Expose the MCP aggregator for UI integrations.""" return self._aggregator + @property + def instruction_template(self) -> str: + """The original instruction template with placeholders.""" + return self._instruction_template or "" + + @property + def instruction_context(self) -> dict[str, str]: + """Context values for instruction template resolution.""" + return self._instruction_context + + @property + def skill_manifests(self) -> list[SkillManifest]: + """List of skill manifests configured for this agent.""" + return self._skill_manifests + + @property + def has_filesystem_runtime(self) -> bool: + """Whether filesystem runtime is available (affects skill tool names).""" + return self._filesystem_runtime is not None + @property def initialized(self) -> 
bool: """Check if both the agent and aggregator are initialized.""" @@ -285,49 +299,31 @@ async def _apply_instruction_templates(self) -> None: Apply template substitution to the instruction, including server instructions. This is called during initialization after servers are connected. """ - if not self._instruction_builder.template: + from fast_agent.core.instruction_refresh import build_instruction + + if not self._instruction_template: return - # Build the instruction using the InstructionBuilder - self.instruction = await self._instruction_builder.build() + # Build the instruction using the central helper + new_instruction = await build_instruction( + self._instruction_template, + aggregator=self._aggregator, + skill_manifests=self._skill_manifests, + has_filesystem_runtime=self.has_filesystem_runtime, + context=self._instruction_context, + ) + self.set_instruction(new_instruction) # Warn if skills configured but placeholder missing - if self._skill_manifests and "{{agentSkills}}" not in self._instruction_builder.template: + if self._skill_manifests and "{{agentSkills}}" not in self._instruction_template: warning_message = ( "Agent skills are configured but the system prompt does not include {{agentSkills}}. " "Skill descriptions will not be added to the system prompt." ) self._record_warning(warning_message) - # Update default request params to match - if self._default_request_params: - self._default_request_params.systemPrompt = self.instruction - self.logger.debug(f"Applied instruction templates for agent {self._name}") - # ───────────────────────────────────────────────────────────────────────── - # Instruction Resolvers (for InstructionBuilder) - # ───────────────────────────────────────────────────────────────────────── - - async def _resolve_server_instructions(self) -> str: - """Resolver for {{serverInstructions}} placeholder.""" - try: - instructions_data = await self._aggregator.get_server_instructions() - return self._format_server_instructions(instructions_data) - except Exception as e: - self.logger.warning(f"Failed to get server instructions: {e}") - return "" - - async def _resolve_agent_skills(self) -> str: - """Resolver for {{agentSkills}} placeholder.""" - # Determine which tool to reference in the preamble - # ACP context provides read_text_file; otherwise use read_skill - if self._filesystem_runtime and hasattr(self._filesystem_runtime, "tools"): - read_tool_name = "read_text_file" - else: - read_tool_name = "read_skill" - return format_skills_for_prompt(self._skill_manifests, read_tool_name=read_tool_name) - def set_skill_manifests(self, manifests: Sequence[SkillManifest]) -> None: self._skill_manifests = list(manifests) self._skill_map = {manifest.name: manifest for manifest in self._skill_manifests} @@ -353,86 +349,17 @@ def warnings(self) -> list[str]: def set_instruction_context(self, context: dict[str, str]) -> None: """ - Set session-level context variables on the instruction builder. + Set session-level context variables for instruction template resolution. This should be called when an ACP session is established to provide - variables like {{env}}, {{workspaceRoot}}, {{agentSkills}} etc. that - are resolved per-session. + variables like {{env}}, {{workspaceRoot}} etc. that are resolved per-session. 
 
         Args:
             context: Dict mapping placeholder names to values
                      (e.g., {"env": "...", "workspaceRoot": "/path"})
         """
-        self._instruction_builder.set_many(context)
+        self._instruction_context.update(context)
         self.logger.debug(f"Set instruction context for agent {self._name}: {list(context.keys())}")
 
-    def _format_server_instructions(
-        self, instructions_data: dict[str, tuple[str | None, list[str]]]
-    ) -> str:
-        """
-        Format server instructions with XML tags and tool lists.
-
-        Args:
-            instructions_data: Dict mapping server name to (instructions, tool_names)
-
-        Returns:
-            Formatted string with server instructions
-        """
-        if not instructions_data:
-            return ""
-
-        formatted_parts = []
-        for server_name, (instructions, tool_names) in instructions_data.items():
-            # Skip servers with no instructions
-            if instructions is None:
-                continue
-
-            # Format tool names with server prefix using the new namespacing convention
-            prefixed_tools = [create_namespaced_name(server_name, tool) for tool in tool_names]
-            tools_list = ", ".join(prefixed_tools) if prefixed_tools else "No tools available"
-
-            formatted_parts.append(
-                f'<server name="{server_name}">\n'
-                f"<tools>{tools_list}</tools>\n"
-                f"<instructions>\n{instructions}\n</instructions>\n"
-                f"</server>"
-            )
-
-        if formatted_parts:
-            return "\n\n".join(formatted_parts)
-        return ""
-
-    async def rebuild_instruction_templates(self) -> None:
-        """
-        Rebuild instruction from template with fresh source values.
-
-        Call this method after connecting new MCP servers (e.g., via /connect command)
-        to update the system prompt with fresh {{serverInstructions}}.
-
-        The InstructionBuilder re-resolves all dynamic sources (serverInstructions,
-        agentSkills, etc.) each time build() is called.
-        """
-        if not self._instruction_builder.template:
-            return
-
-        # Rebuild using the instruction builder (resolvers are called fresh)
-        self.instruction = await self._instruction_builder.build()
-
-        # Update default request params to match
-        if self._default_request_params:
-            self._default_request_params.systemPrompt = self.instruction
-
-        # Invalidate ACP session caches if running in ACP mode
-        if self.context and hasattr(self.context, "acp") and self.context.acp:
-            try:
-                await self.context.acp.invalidate_instruction_cache(
-                    agent_name=self._name,
-                    new_instruction=self.instruction,
-                )
-            except Exception as e:
-                self.logger.warning(f"Failed to invalidate ACP instruction cache: {e}")
-
-        self.logger.info(f"Rebuilt instruction templates for agent {self._name}")
-
     async def __call__(
         self,
         message: Union[
diff --git a/src/fast_agent/agents/workflow/iterative_planner.py b/src/fast_agent/agents/workflow/iterative_planner.py
index 590736810..73a183065 100644
--- a/src/fast_agent/agents/workflow/iterative_planner.py
+++ b/src/fast_agent/agents/workflow/iterative_planner.py
@@ -235,8 +235,8 @@ async def initialize(self) -> None:
         # Replace {{agents}} placeholder in the system prompt template
         system_prompt = self.config.instruction.replace("{{agents}}", agents_str)
 
-        # Update the config instruction with the formatted system prompt
-        self.instruction = system_prompt
+        # Update the instruction with the formatted system prompt
+        self.set_instruction(system_prompt)
 
         # Initialize the base agent with the updated system prompt
         await super().initialize()
diff --git a/src/fast_agent/agents/workflow/router_agent.py b/src/fast_agent/agents/workflow/router_agent.py
index 7d39dc869..ca4808c33 100644
--- a/src/fast_agent/agents/workflow/router_agent.py
+++ b/src/fast_agent/agents/workflow/router_agent.py
@@ -127,8 +127,7 @@ async def initialize(self) -> None:
         combined_system_prompt = 
( ROUTING_SYSTEM_INSTRUCTION + "\n\n" + complete_routing_instruction ) - self._default_request_params.systemPrompt = combined_system_prompt - self.instruction = combined_system_prompt + self.set_instruction(combined_system_prompt) async def shutdown(self) -> None: """Shutdown the router and all agents.""" diff --git a/src/fast_agent/core/fastagent.py b/src/fast_agent/core/fastagent.py index f4fb22e4d..8a997b978 100644 --- a/src/fast_agent/core/fastagent.py +++ b/src/fast_agent/core/fastagent.py @@ -895,25 +895,8 @@ def _apply_instruction_context( if resolved == template: continue - agent.instruction = resolved - - # Note: We intentionally do NOT modify config.instruction here. - # The config should preserve the original template so that - # downstream logic (like MCP display) can check for template - # variables like {{serverInstructions}}. - - request_params = getattr(agent, "_default_request_params", None) - if request_params is not None: - request_params.systemPrompt = resolved - - # TODO -- find a cleaner way of doing this - # Keep any attached LLM in sync so the provider sees the resolved prompt - llm = getattr(agent, "_llm", None) - if llm is not None: - if getattr(llm, "default_request_params", None) is not None: - llm.default_request_params.systemPrompt = resolved - if hasattr(llm, "instruction"): - llm.instruction = resolved + # Use set_instruction() which handles syncing request_params and LLM + agent.set_instruction(resolved) def _apply_skills_to_agent_configs(self, default_skills: list[SkillManifest]) -> None: self._default_skill_manifests = list(default_skills) diff --git a/src/fast_agent/core/instruction.py b/src/fast_agent/core/instruction.py index bb54ef568..17b5b0f13 100644 --- a/src/fast_agent/core/instruction.py +++ b/src/fast_agent/core/instruction.py @@ -5,20 +5,33 @@ from templates with placeholder substitution. Sources can be static values or dynamic resolvers that are called at build time. +Built-in placeholders (automatically resolved): + {{currentDate}} - Current date in "17 December 2025" format + {{hostPlatform}} - Platform info (e.g., "Linux-6.6.0-x86_64") + {{pythonVer}} - Python version (e.g., "3.12.0") + {{url:https://...}} - Fetches content from URL + {{file:path}} - Reads file content (requires workspaceRoot) + {{file_silent:path}} - Reads file, empty string if missing + +Context placeholders (set by caller): + {{workspaceRoot}} - Working directory + {{env}} - Environment description + {{serverInstructions}} - MCP server instructions + {{agentSkills}} - Agent skill descriptions + Usage: - builder = InstructionBuilder(template="You are helpful. {{serverInstructions}}") - builder.set("currentDate", "17 Dec 2025") + builder = InstructionBuilder(template="You are helpful. 
{{currentDate}}") + builder.set("workspaceRoot", "/path/to/workspace") builder.set_resolver("serverInstructions", fetch_server_instructions) instruction = await builder.build() - - # When sources change, rebuild: - new_instruction = await builder.build() """ from __future__ import annotations +import platform import re +from datetime import datetime from pathlib import Path from typing import Awaitable, Callable @@ -30,6 +43,21 @@ Resolver = Callable[[], Awaitable[str]] +def _get_current_date() -> str: + """Return current date in human-readable format.""" + return datetime.now().strftime("%d %B %Y") + + +def _get_host_platform() -> str: + """Return platform information.""" + return platform.platform() + + +def _get_python_version() -> str: + """Return Python version.""" + return platform.python_version() + + def _fetch_url_content(url: str) -> str: """ Fetch content from a URL. @@ -55,16 +83,27 @@ class InstructionBuilder: Builds instruction strings from templates with placeholder substitution. The builder supports two types of sources: - - Static: String values set once (e.g., currentDate) + - Static: String values set once (e.g., workspaceRoot) - Dynamic: Async resolvers called each time build() is invoked (e.g., serverInstructions) - Placeholder syntax: {{name}} for simple placeholders + Built-in placeholders are automatically resolved unless overridden: + - {{currentDate}} - Current date + - {{hostPlatform}} - Platform info + - {{pythonVer}} - Python version + Special patterns: - {{url:https://...}} - Fetches content from URL (resolved at build time) - {{file:path}} - Reads file content relative to workspace (requires workspaceRoot) - {{file_silent:path}} - Like file: but returns empty string if missing """ + # Built-in values that are automatically available + _BUILTINS: dict[str, Callable[[], str]] = { + "currentDate": _get_current_date, + "hostPlatform": _get_host_platform, + "pythonVer": _get_python_version, + } + def __init__(self, template: str): """ Initialize the builder with a template string. @@ -141,8 +180,9 @@ async def build(self) -> str: 1. {{url:...}} patterns (fetch from URL) 2. {{file:...}} patterns (read from file, requires workspaceRoot set) 3. {{file_silent:...}} patterns (read from file, empty if missing) - 4. Static values - 5. Dynamic resolvers + 4. Built-in values (currentDate, hostPlatform, pythonVer) + 5. Static values (override built-ins if set) + 6. Dynamic resolvers Returns: The fully resolved instruction string @@ -158,13 +198,20 @@ async def build(self) -> str: # 3. Resolve {{file_silent:...}} patterns (returns empty if missing) result = self._resolve_file_patterns(result, silent=True) - # 4. Apply static values + # 4. Apply built-in values (can be overridden by static values) + for placeholder, value_fn in self._BUILTINS.items(): + if placeholder not in self._static: # Allow override + pattern = f"{{{{{placeholder}}}}}" + if pattern in result: + result = result.replace(pattern, value_fn()) + + # 5. Apply static values for placeholder, value in self._static.items(): pattern = f"{{{{{placeholder}}}}}" if pattern in result: result = result.replace(pattern, value) - # 5. Resolve dynamic values + # 6. 
Resolve dynamic values for placeholder, resolver in self._resolvers.items(): pattern = f"{{{{{placeholder}}}}}" if pattern in result: @@ -259,7 +306,7 @@ def get_unresolved_placeholders(self) -> set[str]: Set of placeholder names without sources """ all_placeholders = self.get_placeholders() - registered = set(self._static.keys()) | set(self._resolvers.keys()) + registered = set(self._static.keys()) | set(self._resolvers.keys()) | set(self._BUILTINS.keys()) return all_placeholders - registered def copy(self) -> "InstructionBuilder": diff --git a/src/fast_agent/core/instruction_refresh.py b/src/fast_agent/core/instruction_refresh.py index fe65cc696..d93975bf4 100644 --- a/src/fast_agent/core/instruction_refresh.py +++ b/src/fast_agent/core/instruction_refresh.py @@ -1,16 +1,206 @@ +""" +Instruction building and refresh utilities. + +This module provides the central logic for building agent instructions from +templates. It consolidates instruction building that was previously spread +across McpAgent and other modules. + +The InstructionBuilder handles template resolution (placeholders like {{currentDate}}, +{{file:path}}, etc.), while this module provides higher-level functions that: +- Gather data from agent sources (MCP servers, skills, etc.) +- Build the complete instruction using InstructionBuilder +- Set the instruction on the agent via set_instruction() +""" + from __future__ import annotations import asyncio from dataclasses import dataclass -from typing import Any, Mapping +from typing import TYPE_CHECKING, Any, Mapping, Protocol, Sequence, runtime_checkable from weakref import WeakKeyDictionary +from fast_agent.core.instruction import InstructionBuilder +from fast_agent.core.logging.logger import get_logger +from fast_agent.mcp.common import create_namespaced_name + +if TYPE_CHECKING: + from fast_agent.mcp.mcp_aggregator import MCPAggregator + from fast_agent.skills import SkillManifest + from fast_agent.skills.registry import SkillRegistry + +logger = get_logger(__name__) + + +# ───────────────────────────────────────────────────────────────────────────── +# Protocols +# ───────────────────────────────────────────────────────────────────────────── + + +@runtime_checkable +class InstructionCapable(Protocol): + """Protocol for agents that support instruction get/set.""" + + @property + def instruction(self) -> str: ... + + def set_instruction(self, instruction: str) -> None: ... + + +@runtime_checkable +class McpInstructionCapable(InstructionCapable, Protocol): + """Protocol for MCP agents that support full instruction refresh.""" + + @property + def instruction_template(self) -> str: ... + + @property + def instruction_context(self) -> dict[str, str]: ... + + @property + def aggregator(self) -> "MCPAggregator": ... + + @property + def skill_manifests(self) -> Sequence["SkillManifest"]: ... + + @property + def skill_registry(self) -> "SkillRegistry | None": ... + + @skill_registry.setter + def skill_registry(self, value: "SkillRegistry | None") -> None: ... + + def set_skill_manifests(self, manifests: Sequence["SkillManifest"]) -> None: ... + + def set_instruction_context(self, context: dict[str, str]) -> None: ... + + @property + def has_filesystem_runtime(self) -> bool: ... 
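The helpers below delegate placeholder resolution to InstructionBuilder: built-ins fill in automatically, static values come from set(), and resolvers run when build() is awaited. A minimal runnable sketch using only the API shown above; the template text and workspace path are placeholder values:

    import asyncio

    from fast_agent.core.instruction import InstructionBuilder

    async def demo() -> str:
        builder = InstructionBuilder("Date: {{currentDate}}. Root: {{workspaceRoot}}.")
        builder.set("workspaceRoot", "/workspace")  # static value
        return await builder.build()  # {{currentDate}} resolves via the built-in

    print(asyncio.run(demo()))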
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# Instruction Building
+# ─────────────────────────────────────────────────────────────────────────────
+
+
+def format_server_instructions(
+    instructions_data: dict[str, tuple[str | None, list[str]]]
+) -> str:
+    """
+    Format server instructions with XML tags and tool lists.
+
+    Args:
+        instructions_data: Dict mapping server name to (instructions, tool_names)
+
+    Returns:
+        Formatted string with server instructions
+    """
+    if not instructions_data:
+        return ""
+
+    formatted_parts = []
+    for server_name, (instructions, tool_names) in instructions_data.items():
+        if instructions is None:
+            continue
+
+        prefixed_tools = [create_namespaced_name(server_name, tool) for tool in tool_names]
+        tools_list = ", ".join(prefixed_tools) if prefixed_tools else "No tools available"
+
+        formatted_parts.append(
+            f'<server name="{server_name}">\n'
+            f"<tools>{tools_list}</tools>\n"
+            f"<instructions>\n{instructions}\n</instructions>\n"
+            f"</server>"
+        )
+
+    return "\n\n".join(formatted_parts) if formatted_parts else ""
+
+
+def format_agent_skills(
+    manifests: Sequence["SkillManifest"],
+    has_filesystem_runtime: bool = False,
+) -> str:
+    """
+    Format skill manifests for inclusion in the instruction.
+
+    Args:
+        manifests: List of skill manifests
+        has_filesystem_runtime: Whether filesystem runtime is available (affects tool name)
+
+    Returns:
+        Formatted skills text
+    """
+    from fast_agent.skills.registry import format_skills_for_prompt
+
+    read_tool_name = "read_text_file" if has_filesystem_runtime else "read_skill"
+    return format_skills_for_prompt(manifests, read_tool_name=read_tool_name)
+
+
+async def build_instruction(
+    template: str,
+    *,
+    aggregator: "MCPAggregator | None" = None,
+    skill_manifests: Sequence["SkillManifest"] | None = None,
+    has_filesystem_runtime: bool = False,
+    context: Mapping[str, str] | None = None,
+) -> str:
+    """
+    Build an instruction string from a template with all placeholders resolved.
+
+    This is the main entry point for building agent instructions. It:
+    1. Creates an InstructionBuilder with the template
+    2. Sets up resolvers for serverInstructions and agentSkills
+    3. Applies any context values
+    4. Builds and returns the final instruction
+
+    Args:
+        template: The instruction template with {{placeholder}} patterns
+        aggregator: MCP aggregator for fetching server instructions
+        skill_manifests: List of skill manifests for {{agentSkills}}
+        has_filesystem_runtime: Whether filesystem runtime is available
+        context: Additional context values (env, workspaceRoot, etc.) 
+ + Returns: + The fully resolved instruction string + """ + builder = InstructionBuilder(template) + + # Set up server instructions resolver + if aggregator is not None: + + async def resolve_server_instructions() -> str: + try: + instructions_data = await aggregator.get_server_instructions() + return format_server_instructions(instructions_data) + except Exception as e: + logger.warning(f"Failed to get server instructions: {e}") + return "" + + builder.set_resolver("serverInstructions", resolve_server_instructions) + + # Set up agent skills resolver + if skill_manifests is not None: + + async def resolve_agent_skills() -> str: + return format_agent_skills(skill_manifests, has_filesystem_runtime) + + builder.set_resolver("agentSkills", resolve_agent_skills) + + # Apply context values + if context: + builder.set_many(dict(context)) + + return await builder.build() + + +# ───────────────────────────────────────────────────────────────────────────── +# Agent Instruction Refresh +# ───────────────────────────────────────────────────────────────────────────── + @dataclass(frozen=True) class InstructionRefreshResult: updated_skill_manifests: bool = False - updated_instruction_context: bool = False updated_skill_registry: bool = False + updated_context: bool = False rebuilt_instruction: bool = False @@ -38,36 +228,77 @@ async def rebuild_agent_instruction( agent: object, *, skill_manifests: list[Any] | None = None, - instruction_context: Mapping[str, str] | None = None, skill_registry: Any | None = None, + context: Mapping[str, str] | None = None, ) -> InstructionRefreshResult: - """Serialize rebuilds and apply optional instruction context updates.""" + """ + Rebuild an agent's instruction from its template. + + This function: + 1. Optionally updates skill_manifests and skill_registry on the agent + 2. Optionally updates instruction_context if context is provided + 3. Builds the instruction using the agent's template and data sources + 4. Sets the new instruction on the agent + + Args: + agent: The agent to refresh (must implement McpInstructionCapable) + skill_manifests: Optional new skill manifests to set + skill_registry: Optional new skill registry to set + context: Optional new context values to set and use in the build. + If not provided, uses the agent's stored instruction_context. 
+ + Returns: + InstructionRefreshResult indicating what was updated + """ lock = _get_instruction_lock(agent) async with lock: updated_skill_manifests = False - updated_instruction_context = False updated_skill_registry = False + updated_context = False rebuilt_instruction = False - if skill_manifests is not None and hasattr(agent, "set_skill_manifests"): + if not isinstance(agent, McpInstructionCapable): + return InstructionRefreshResult() + + # Update agent state if new values provided + if skill_manifests is not None: agent.set_skill_manifests(skill_manifests) updated_skill_manifests = True - if instruction_context is not None and hasattr(agent, "set_instruction_context"): - agent.set_instruction_context(dict(instruction_context)) - updated_instruction_context = True - - if skill_registry is not None and hasattr(agent, "skill_registry"): + if skill_registry is not None: agent.skill_registry = skill_registry updated_skill_registry = True - if hasattr(agent, "rebuild_instruction_templates"): - await agent.rebuild_instruction_templates() - rebuilt_instruction = True + if context is not None: + agent.set_instruction_context(dict(context)) + updated_context = True + + # Build the instruction using the agent's current state + template = agent.instruction_template + if not template: + return InstructionRefreshResult( + updated_skill_manifests=updated_skill_manifests, + updated_skill_registry=updated_skill_registry, + updated_context=updated_context, + ) + + # Use agent's stored context (which may have just been updated) + build_context = agent.instruction_context + + new_instruction = await build_instruction( + template, + aggregator=agent.aggregator, + skill_manifests=agent.skill_manifests, + has_filesystem_runtime=agent.has_filesystem_runtime, + context=build_context, + ) + + agent.set_instruction(new_instruction) + rebuilt_instruction = True return InstructionRefreshResult( updated_skill_manifests=updated_skill_manifests, - updated_instruction_context=updated_instruction_context, updated_skill_registry=updated_skill_registry, + updated_context=updated_context, rebuilt_instruction=rebuilt_instruction, ) diff --git a/src/fast_agent/interfaces.py b/src/fast_agent/interfaces.py index 607e91dea..ea3b4ea7a 100644 --- a/src/fast_agent/interfaces.py +++ b/src/fast_agent/interfaces.py @@ -32,7 +32,8 @@ if TYPE_CHECKING: from fast_agent.acp.acp_aware_mixin import ACPCommand, ACPModeInfo from fast_agent.acp.acp_context import ACPContext - from fast_agent.agents.agent_types import AgentType + from fast_agent.agents.agent_types import AgentConfig, AgentType + from fast_agent.context import Context from fast_agent.llm.model_info import ModelInfo __all__ = [ @@ -255,6 +256,12 @@ async def attach_llm( @property def initialized(self) -> bool: ... + instruction: str + config: "AgentConfig" + context: "Context | None" + + def set_instruction(self, instruction: str) -> None: ... 
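Callers now refresh an agent through rebuild_agent_instruction() rather than a method on the agent; anything that does not satisfy McpInstructionCapable simply gets back an all-False result. A minimal sketch, assuming `agent` is an McpAgent and using a placeholder workspace path:

    from fast_agent.core.instruction_refresh import rebuild_agent_instruction

    async def refresh_workspace(agent: object) -> bool:
        # Context values are stored on the agent, then the instruction is
        # rebuilt from its template and applied via set_instruction().
        result = await rebuild_agent_instruction(
            agent, context={"workspaceRoot": "/workspace"}
        )
        return result.rebuilt_instruction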
+ @runtime_checkable class StreamingAgentProtocol(AgentProtocol, Protocol): diff --git a/src/fast_agent/mcp/types.py b/src/fast_agent/mcp/types.py index d232a6b57..dd758f4dd 100644 --- a/src/fast_agent/mcp/types.py +++ b/src/fast_agent/mcp/types.py @@ -1,12 +1,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Protocol, Sequence, runtime_checkable from fast_agent.interfaces import AgentProtocol if TYPE_CHECKING: from fast_agent.context import Context from fast_agent.mcp.mcp_aggregator import MCPAggregator + from fast_agent.skills import SkillManifest + from fast_agent.skills.registry import SkillRegistry from fast_agent.ui.console_display import ConsoleDisplay @@ -22,3 +24,25 @@ def display(self) -> "ConsoleDisplay": ... @property def context(self) -> "Context | None": ... + + @property + def instruction_template(self) -> str: ... + + @property + def instruction_context(self) -> dict[str, str]: ... + + @property + def skill_manifests(self) -> Sequence["SkillManifest"]: ... + + @property + def has_filesystem_runtime(self) -> bool: ... + + def set_skill_manifests(self, manifests: Sequence["SkillManifest"]) -> None: ... + + def set_instruction_context(self, context: dict[str, str]) -> None: ... + + @property + def skill_registry(self) -> "SkillRegistry | None": ... + + @skill_registry.setter + def skill_registry(self, value: "SkillRegistry | None") -> None: ... diff --git a/src/fast_agent/tools/elicitation.py b/src/fast_agent/tools/elicitation.py index a6815117a..5743a5700 100644 --- a/src/fast_agent/tools/elicitation.py +++ b/src/fast_agent/tools/elicitation.py @@ -198,8 +198,9 @@ def parse_schema_string(val: str) -> dict | None: title: str | None = None description: str | None = None - if isinstance(arguments.get("fields"), list): - fields = arguments.get("fields") + fields_value = arguments.get("fields") + if isinstance(fields_value, list): + fields = fields_value if len(fields) > 7: raise ValueError( f"Error: form requests {len(fields)} fields; the maximum allowed is 7." @@ -289,17 +290,20 @@ def parse_schema_string(val: str) -> dict | None: schema = parsed else: raise ValueError("Missing or invalid schema. Provide a JSON Schema object.") + if not isinstance(schema, dict): + raise ValueError("Missing or invalid schema. 
Provide a JSON Schema object.") + schema_dict: dict[str, Any] = schema msg = arguments.get("message") if isinstance(msg, str): message = msg - if isinstance(arguments.get("title"), str) and "title" not in schema: - schema["title"] = arguments.get("title") - if isinstance(arguments.get("description"), str) and "description" not in schema: - schema["description"] = arguments.get("description") - if isinstance(arguments.get("required"), list) and "required" not in schema: - schema["required"] = arguments.get("required") - if isinstance(arguments.get("properties"), dict) and "properties" not in schema: - schema["properties"] = arguments.get("properties") + if isinstance(arguments.get("title"), str) and "title" not in schema_dict: + schema_dict["title"] = arguments.get("title") + if isinstance(arguments.get("description"), str) and "description" not in schema_dict: + schema_dict["description"] = arguments.get("description") + if isinstance(arguments.get("required"), list) and "required" not in schema_dict: + schema_dict["required"] = arguments.get("required") + if isinstance(arguments.get("properties"), dict) and "properties" not in schema_dict: + schema_dict["properties"] = arguments.get("properties") elif ("type" in arguments and "properties" in arguments) or ( "$schema" in arguments and "properties" in arguments @@ -309,17 +313,25 @@ def parse_schema_string(val: str) -> dict | None: else: raise ValueError("Missing or invalid schema or fields in arguments.") - props = schema.get("properties", {}) if isinstance(schema.get("properties"), dict) else {} + if not isinstance(schema, dict): + raise ValueError("Missing or invalid schema or fields in arguments.") + + schema_dict: dict[str, Any] = schema + props = ( + schema_dict.get("properties", {}) + if isinstance(schema_dict.get("properties"), dict) + else {} + ) if len(props) > 7: raise ValueError(f"Error: schema requests {len(props)} fields; the maximum allowed is 7.") request_payload: dict[str, Any] = { - "prompt": message or schema.get("title") or "Please complete this form:", - "description": schema.get("description"), + "prompt": message or schema_dict.get("title") or "Please complete this form:", + "description": schema_dict.get("description"), "request_id": f"__human_input__{uuid.uuid4()}", "metadata": { "agent_name": agent_name or "Unknown Agent", - "requested_schema": schema, + "requested_schema": schema_dict, }, } diff --git a/src/fast_agent/tools/shell_runtime.py b/src/fast_agent/tools/shell_runtime.py index 4e9dd9c63..bf973d240 100644 --- a/src/fast_agent/tools/shell_runtime.py +++ b/src/fast_agent/tools/shell_runtime.py @@ -167,7 +167,9 @@ async def execute(self, arguments: dict[str, Any] | None = None) -> CallToolResu if is_windows: # Windows: CREATE_NEW_PROCESS_GROUP allows killing process tree - process_kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP + creation_flags = getattr(subprocess, "CREATE_NEW_PROCESS_GROUP", 0) + if creation_flags: + process_kwargs["creationflags"] = creation_flags else: # Unix: start_new_session creates new process group process_kwargs["start_new_session"] = True @@ -256,7 +258,9 @@ async def watchdog() -> None: if is_windows: # Windows: try to signal the entire process group before terminating try: - process.send_signal(signal.CTRL_BREAK_EVENT) + ctrl_break = getattr(signal, "CTRL_BREAK_EVENT", None) + if ctrl_break is not None: + process.send_signal(ctrl_break) await asyncio.sleep(2) except AttributeError: # Older Python/asyncio may not support send_signal on Windows diff --git 
a/src/fast_agent/ui/command_payloads.py b/src/fast_agent/ui/command_payloads.py new file mode 100644 index 000000000..c6d96352f --- /dev/null +++ b/src/fast_agent/ui/command_payloads.py @@ -0,0 +1,104 @@ +from dataclasses import dataclass +from typing import Literal + + +class CommandBase: + kind: str + + +@dataclass(frozen=True, slots=True) +class ShowUsageCommand(CommandBase): + kind: Literal["show_usage"] = "show_usage" + + +@dataclass(frozen=True, slots=True) +class ShowSystemCommand(CommandBase): + kind: Literal["show_system"] = "show_system" + + +@dataclass(frozen=True, slots=True) +class ShowMarkdownCommand(CommandBase): + kind: Literal["show_markdown"] = "show_markdown" + + +@dataclass(frozen=True, slots=True) +class ShowMcpStatusCommand(CommandBase): + kind: Literal["show_mcp_status"] = "show_mcp_status" + + +@dataclass(frozen=True, slots=True) +class ListToolsCommand(CommandBase): + kind: Literal["list_tools"] = "list_tools" + + +@dataclass(frozen=True, slots=True) +class ListPromptsCommand(CommandBase): + kind: Literal["list_prompts"] = "list_prompts" + + +@dataclass(frozen=True, slots=True) +class ListSkillsCommand(CommandBase): + kind: Literal["list_skills"] = "list_skills" + + +@dataclass(frozen=True, slots=True) +class ShowHistoryCommand(CommandBase): + agent: str | None + kind: Literal["show_history"] = "show_history" + + +@dataclass(frozen=True, slots=True) +class ClearCommand(CommandBase): + kind: Literal["clear_history", "clear_last"] + agent: str | None + + +@dataclass(frozen=True, slots=True) +class SkillsCommand(CommandBase): + action: str + argument: str | None + kind: Literal["skills_command"] = "skills_command" + + +@dataclass(frozen=True, slots=True) +class SelectPromptCommand(CommandBase): + prompt_name: str | None + prompt_index: int | None + kind: Literal["select_prompt"] = "select_prompt" + + +@dataclass(frozen=True, slots=True) +class SwitchAgentCommand(CommandBase): + agent_name: str + kind: Literal["switch_agent"] = "switch_agent" + + +@dataclass(frozen=True, slots=True) +class SaveHistoryCommand(CommandBase): + filename: str | None + kind: Literal["save_history"] = "save_history" + + +@dataclass(frozen=True, slots=True) +class LoadHistoryCommand(CommandBase): + filename: str | None + error: str | None + kind: Literal["load_history"] = "load_history" + + +CommandPayload = ( + ShowUsageCommand + | ShowSystemCommand + | ShowMarkdownCommand + | ShowMcpStatusCommand + | ListToolsCommand + | ListPromptsCommand + | ListSkillsCommand + | ShowHistoryCommand + | ClearCommand + | SkillsCommand + | SelectPromptCommand + | SwitchAgentCommand + | SaveHistoryCommand + | LoadHistoryCommand +) diff --git a/src/fast_agent/ui/enhanced_prompt.py b/src/fast_agent/ui/enhanced_prompt.py index c0b1c5981..a657c9532 100644 --- a/src/fast_agent/ui/enhanced_prompt.py +++ b/src/fast_agent/ui/enhanced_prompt.py @@ -8,6 +8,7 @@ import shlex import subprocess import tempfile +from collections.abc import Callable, Iterable from importlib.metadata import version from pathlib import Path from typing import TYPE_CHECKING, Any @@ -22,10 +23,28 @@ from rich import print as rich_print from fast_agent.agents.agent_types import AgentType +from fast_agent.agents.workflow.parallel_agent import ParallelAgent +from fast_agent.agents.workflow.router_agent import RouterAgent from fast_agent.constants import FAST_AGENT_ERROR_CHANNEL, FAST_AGENT_REMOVED_METADATA_CHANNEL from fast_agent.core.exceptions import PromptExitError from fast_agent.llm.model_info import ModelInfo from fast_agent.mcp.types import 
McpAgentProtocol +from fast_agent.ui.command_payloads import ( + ClearCommand, + CommandBase, + CommandPayload, + ListToolsCommand, + LoadHistoryCommand, + SaveHistoryCommand, + SelectPromptCommand, + ShowHistoryCommand, + ShowMarkdownCommand, + ShowMcpStatusCommand, + ShowSystemCommand, + ShowUsageCommand, + SkillsCommand, + SwitchAgentCommand, +) from fast_agent.ui.mcp_display import render_mcp_status if TYPE_CHECKING: @@ -47,6 +66,60 @@ in_multiline_mode = False +def _show_system_cmd() -> ShowSystemCommand: + return ShowSystemCommand() + + +def _show_usage_cmd() -> ShowUsageCommand: + return ShowUsageCommand() + + +def _show_markdown_cmd() -> ShowMarkdownCommand: + return ShowMarkdownCommand() + + +def _show_mcp_status_cmd() -> ShowMcpStatusCommand: + return ShowMcpStatusCommand() + + +def _list_tools_cmd() -> ListToolsCommand: + return ListToolsCommand() + + +def _switch_agent_cmd(agent_name: str) -> SwitchAgentCommand: + return SwitchAgentCommand(agent_name=agent_name) + + +def _show_history_cmd(target_agent: str | None) -> ShowHistoryCommand: + return ShowHistoryCommand(agent=target_agent) + + +def _clear_last_cmd(target_agent: str | None) -> ClearCommand: + return ClearCommand(kind="clear_last", agent=target_agent) + + +def _clear_history_cmd(target_agent: str | None) -> ClearCommand: + return ClearCommand(kind="clear_history", agent=target_agent) + + +def _save_history_cmd(filename: str | None) -> SaveHistoryCommand: + return SaveHistoryCommand(filename=filename) + + +def _load_history_cmd(filename: str | None, error: str | None) -> LoadHistoryCommand: + return LoadHistoryCommand(filename=filename, error=error) + + +def _select_prompt_cmd( + prompt_index: int | None, prompt_name: str | None +) -> SelectPromptCommand: + return SelectPromptCommand(prompt_index=prompt_index, prompt_name=prompt_name) + + +def _skills_cmd(action: str, argument: str | None) -> SkillsCommand: + return SkillsCommand(action=action, argument=argument) + + def _extract_alert_flags_from_meta(blocks) -> set[str]: flags: set[str] = set() for block in blocks or []: @@ -131,12 +204,12 @@ async def _display_agent_info_helper(agent_name: str, agent_provider: "AgentApp skill_count = 0 # Handle different agent types - if agent.agent_type == AgentType.PARALLEL: + if isinstance(agent, ParallelAgent): # Count child agents for parallel agents child_count = 0 - if hasattr(agent, "fan_out_agents") and agent.fan_out_agents: + if agent.fan_out_agents: child_count += len(agent.fan_out_agents) - if hasattr(agent, "fan_in_agent") and agent.fan_in_agent: + if agent.fan_in_agent: child_count += 1 if child_count > 0: @@ -144,12 +217,10 @@ async def _display_agent_info_helper(agent_name: str, agent_provider: "AgentApp rich_print( f"[dim]Agent [/dim][blue]{agent_name}[/blue][dim]:[/dim] {child_count:,}[dim] {child_word}[/dim]" ) - elif agent.agent_type == AgentType.ROUTER: + elif isinstance(agent, RouterAgent): # Count child agents for router agents child_count = 0 - if hasattr(agent, "routing_agents") and agent.routing_agents: - child_count = len(agent.routing_agents) - elif hasattr(agent, "agents") and agent.agents: + if agent.agents: child_count = len(agent.agents) if child_count > 0: @@ -213,37 +284,37 @@ async def _display_agent_info_helper(agent_name: str, agent_provider: "AgentApp async def _display_all_agents_with_hierarchy( - available_agents: list[str], agent_provider: "AgentApp | None" + available_agents: Iterable[str], agent_provider: "AgentApp | None" ) -> None: """Display all agents with tree structure for workflow 
agents.""" + agent_list = list(available_agents) # Track which agents are children to avoid displaying them twice child_agents = set() # First pass: identify all child agents - for agent_name in available_agents: + for agent_name in agent_list: try: if agent_provider is None: continue agent = agent_provider._agent(agent_name) - if agent.agent_type == AgentType.PARALLEL: - if hasattr(agent, "fan_out_agents") and agent.fan_out_agents: + if isinstance(agent, ParallelAgent): + if agent.fan_out_agents: for child_agent in agent.fan_out_agents: - child_agents.add(child_agent.name) - if hasattr(agent, "fan_in_agent") and agent.fan_in_agent: + if child_agent.name: + child_agents.add(child_agent.name) + if agent.fan_in_agent and agent.fan_in_agent.name: child_agents.add(agent.fan_in_agent.name) - elif agent.agent_type == AgentType.ROUTER: - if hasattr(agent, "routing_agents") and agent.routing_agents: - for child_agent in agent.routing_agents: - child_agents.add(child_agent.name) - elif hasattr(agent, "agents") and agent.agents: + elif isinstance(agent, RouterAgent): + if agent.agents: for child_agent in agent.agents: - child_agents.add(child_agent.name) + if child_agent.name: + child_agents.add(child_agent.name) except Exception: continue # Second pass: display agents (parents with children, standalone agents without children) - for agent_name in sorted(available_agents): + for agent_name in sorted(agent_list): # Skip if this agent is a child of another agent if agent_name in child_agents: continue @@ -271,12 +342,12 @@ async def _display_parallel_children(parallel_agent, agent_provider: "AgentApp | children = [] # Collect fan-out agents - if hasattr(parallel_agent, "fan_out_agents") and parallel_agent.fan_out_agents: + if parallel_agent.fan_out_agents: for child_agent in parallel_agent.fan_out_agents: children.append(child_agent) # Collect fan-in agent - if hasattr(parallel_agent, "fan_in_agent") and parallel_agent.fan_in_agent: + if parallel_agent.fan_in_agent is not None: children.append(parallel_agent.fan_in_agent) # Display children with tree formatting @@ -291,10 +362,8 @@ async def _display_router_children(router_agent, agent_provider: "AgentApp | Non children = [] # Collect routing agents - if hasattr(router_agent, "routing_agents") and router_agent.routing_agents: - children = router_agent.routing_agents - elif hasattr(router_agent, "agents") and router_agent.agents: - children = router_agent.agents + if router_agent.agents: + children = list(router_agent.agents) # Display children with tree formatting for i, child_agent in enumerate(children): @@ -353,8 +422,7 @@ class AgentCompleter(Completer): def __init__( self, agents: list[str], - commands: list[str] = None, - agent_types: dict = None, + agent_types: dict[str, AgentType] | None = None, is_human_input: bool = False, ) -> None: self.agents = agents @@ -376,7 +444,6 @@ def __init__( "help": "Show commands and shortcuts", "EXIT": "Exit fast-agent, terminating any running workflows", "STOP": "Stop this prompting session and move to next workflow step", - **(commands or {}), # Allow custom commands to be passed in } if is_human_input: self.commands.pop("agents") @@ -568,11 +635,19 @@ def get_text_from_editor(initial_text: str = "") -> str: return edited_text.strip() # Added strip() to remove trailing newlines often added by editors +class AgentKeyBindings(KeyBindings): + agent_provider: "AgentApp | None" = None + current_agent_name: str | None = None + + def create_keybindings( - on_toggle_multiline=None, app=None, agent_provider: 
"AgentApp | None" = None, agent_name=None -): + on_toggle_multiline: Callable[[bool], None] | None = None, + app: Any | None = None, + agent_provider: "AgentApp | None" = None, + agent_name: str | None = None, +) -> AgentKeyBindings: """Create custom key bindings.""" - kb = KeyBindings() + kb = AgentKeyBindings() @kb.add("c-m", filter=Condition(lambda: not in_multiline_mode)) def _(event) -> None: @@ -675,7 +750,7 @@ async def _(event) -> None: return kb -def parse_special_input(text: str) -> str | dict[str, Any]: +def parse_special_input(text: str) -> str | CommandPayload: stripped = text.lstrip() if stripped.startswith("/"): cmd_line = stripped.splitlines()[0] @@ -694,16 +769,16 @@ def parse_special_input(text: str) -> str | dict[str, Any]: if cmd == "agents": return "LIST_AGENTS" if cmd == "system": - return "SHOW_SYSTEM" + return _show_system_cmd() if cmd == "usage": - return "SHOW_USAGE" + return _show_usage_cmd() if cmd == "history": target_agent = None if len(cmd_parts) > 1: candidate = cmd_parts[1].strip() if candidate: target_agent = candidate - return {"show_history": {"agent": target_agent}} + return _show_history_cmd(target_agent) if cmd == "clear": target_agent = None if len(cmd_parts) > 1: @@ -715,49 +790,49 @@ def parse_special_input(text: str) -> str | dict[str, Any]: candidate = tokens[1].strip() if candidate: target_agent = candidate - return {"clear_last": {"agent": target_agent}} + return _clear_last_cmd(target_agent) target_agent = remainder - return {"clear_history": {"agent": target_agent}} + return _clear_history_cmd(target_agent) if cmd == "markdown": - return "MARKDOWN" + return _show_markdown_cmd() if cmd in ("save_history", "save"): filename = ( cmd_parts[1].strip() if len(cmd_parts) > 1 and cmd_parts[1].strip() else None ) - return {"save_history": True, "filename": filename} + return _save_history_cmd(filename) if cmd in ("load_history", "load"): filename = ( cmd_parts[1].strip() if len(cmd_parts) > 1 and cmd_parts[1].strip() else None ) if not filename: - return {"load_history": True, "error": "Filename required for load_history"} - return {"load_history": True, "filename": filename} + return _load_history_cmd(None, "Filename required for load_history") + return _load_history_cmd(filename, None) if cmd in ("mcpstatus", "mcp"): - return {"show_mcp_status": True} + return _show_mcp_status_cmd() if cmd == "prompt": if len(cmd_parts) > 1: prompt_arg = cmd_parts[1].strip() if prompt_arg.isdigit(): - return {"select_prompt": True, "prompt_index": int(prompt_arg)} - return f"SELECT_PROMPT:{prompt_arg}" - return {"select_prompt": True, "prompt_name": None} + return _select_prompt_cmd(int(prompt_arg), None) + return _select_prompt_cmd(None, prompt_arg) + return _select_prompt_cmd(None, None) if cmd == "tools": - return {"list_tools": True} + return _list_tools_cmd() if cmd == "skills": remainder = cmd_parts[1].strip() if len(cmd_parts) > 1 else "" if not remainder: - return {"skills_command": {"action": "list", "argument": None}} + return _skills_cmd("list", None) tokens = remainder.split(maxsplit=1) action = tokens[0].lower() argument = tokens[1].strip() if len(tokens) > 1 else None - return {"skills_command": {"action": action, "argument": argument}} + return _skills_cmd(action, argument) if cmd == "exit": return "EXIT" if cmd.lower() == "stop": return "STOP" if cmd_line and cmd_line.startswith("@"): - return f"SWITCH:{cmd_line[1:].strip()}" + return _switch_agent_cmd(cmd_line[1:].strip()) return text @@ -768,12 +843,12 @@ async def get_enhanced_input( show_default: 
bool = False, show_stop_hint: bool = False, multiline: bool = False, - available_agent_names: list[str] = None, - agent_types: dict[str, AgentType] = None, + available_agent_names: list[str] | None = None, + agent_types: dict[str, AgentType] | None = None, is_human_input: bool = False, toolbar_color: str = "ansiblue", agent_provider: "AgentApp | None" = None, -) -> str: +) -> str | CommandPayload: """ Enhanced input with advanced prompt_toolkit features. @@ -790,7 +865,7 @@ async def get_enhanced_input( agent_provider: Optional AgentApp for displaying agent info Returns: - User input string + User input string or parsed command payload """ global in_multiline_mode, available_agents, help_message_shown @@ -862,15 +937,16 @@ def get_toolbar(): ) if not model_name: - model_name = getattr(agent.config, "model", None) - if not model_name and getattr(agent.config, "default_request_params", None): - model_name = getattr(agent.config.default_request_params, "model", None) + model_name = agent.config.model + if not model_name and agent.config.default_request_params: + model_name = agent.config.default_request_params.model if not model_name: - context = getattr(agent, "context", None) or getattr( - agent_provider, "context", None - ) - config_obj = getattr(context, "config", None) if context else None - model_name = getattr(config_obj, "default_model", None) + try: + context = agent.context + except Exception: + context = None + if context and context.config: + model_name = context.config.default_model if model_name: max_len = 25 @@ -1080,8 +1156,12 @@ def _style_flag(letter: str, supported: bool) -> str: active_agent = shell_agent if active_agent is None: active_agent = agent_provider._agent(agent_name) - agent_context = active_agent._context or active_agent.context - logger_settings = agent_context.config.logger + try: + agent_context = active_agent.context + except Exception: + agent_context = None + if agent_context and agent_context.config: + logger_settings = agent_context.config.logger except Exception: # If we can't get the agent or its context, logger_settings stays None pass @@ -1115,13 +1195,17 @@ def _style_flag(letter: str, supported: bool) -> str: ) # Show model source if configured via env var or config file - model_source = getattr(agent_context.config, "model_source", None) + model_source = ( + getattr(agent_context.config, "model_source", None) + if agent_context and agent_context.config + else None + ) if model_source: rich_print(f"[dim]Model selected via {model_source}[/dim]") # Show HuggingFace model and provider info if applicable try: - if active_agent.llm: + if active_agent and active_agent.llm: get_hf_info = getattr(active_agent.llm, "get_hf_display_info", None) if get_hf_info: hf_info = get_hf_info() @@ -1188,8 +1272,8 @@ def _style_flag(letter: str, supported: bool) -> str: async def get_selection_input( prompt_text: str, - options: list[str] = None, - default: str = None, + options: list[str] | None = None, + default: str | None = None, allow_cancel: bool = True, complete_options: bool = True, ) -> str | None: @@ -1235,7 +1319,7 @@ async def get_selection_input( async def get_argument_input( arg_name: str, - description: str = None, + description: str | None = None, required: bool = True, ) -> str | None: """ @@ -1284,8 +1368,8 @@ async def get_argument_input( async def handle_special_commands( - command: Any, agent_app: "AgentApp | None" = None -) -> bool | dict[str, Any]: + command: str | CommandPayload | None, agent_app: "AgentApp | bool | None" = None +) -> bool | 
CommandPayload: """ Handle special input commands. @@ -1300,9 +1384,8 @@ async def handle_special_commands( if not command: return False - # If command is already a dictionary, it has been pre-processed - # Just return it directly (like when /prompts converts to select_prompt dict) - if isinstance(command, dict): + # If command is already a command payload, it has been pre-processed. + if isinstance(command, CommandBase): return command global agent_histories @@ -1358,16 +1441,13 @@ async def handle_special_commands( return True elif command == "SHOW_USAGE": - # Return a dictionary to signal that usage should be shown - return {"show_usage": True} + return _show_usage_cmd() elif command == "SHOW_SYSTEM": - # Return a dictionary to signal that system prompt should be shown - return {"show_system": True} + return _show_system_cmd() elif command == "MARKDOWN": - # Return a dictionary to signal that markdown display should be shown - return {"show_markdown": True} + return _show_markdown_cmd() elif command == "SELECT_PROMPT" or ( isinstance(command, str) and command.startswith("SELECT_PROMPT:") @@ -1380,7 +1460,7 @@ async def handle_special_commands( prompt_name = command.split(":", 1)[1].strip() # Return a dictionary with a select_prompt action to be handled by the caller - return {"select_prompt": True, "prompt_name": prompt_name} + return _select_prompt_cmd(None, prompt_name) else: rich_print( "[yellow]Prompt selection is not available outside of an agent context[/yellow]" @@ -1392,7 +1472,7 @@ async def handle_special_commands( if agent_name in available_agents: if agent_app: # The parameter can be the actual agent_app or just True to enable switching - return {"switch_agent": agent_name} + return _switch_agent_cmd(agent_name) else: rich_print("[yellow]Agent switching not available in this context[/yellow]") else: diff --git a/src/fast_agent/ui/interactive_prompt.py b/src/fast_agent/ui/interactive_prompt.py index d26716ced..f51e77fe8 100644 --- a/src/fast_agent/ui/interactive_prompt.py +++ b/src/fast_agent/ui/interactive_prompt.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from fast_agent.core.agent_app import AgentApp - + from fast_agent.ui.command_payloads import CommandPayload from mcp.types import Prompt, PromptMessage from rich import print as rich_print @@ -48,6 +48,7 @@ ) from fast_agent.skills.registry import format_skills_for_prompt from fast_agent.types import PromptMessageExtended +from fast_agent.ui.command_payloads import CommandBase from fast_agent.ui.enhanced_prompt import ( _display_agent_info_helper, get_argument_input, @@ -139,10 +140,11 @@ async def prompt_loop( command_result = await handle_special_commands(user_input, True) # Check if we should switch agents - if isinstance(command_result, dict): - command_dict: dict[str, Any] = command_result - if "switch_agent" in command_dict: - new_agent = command_dict["switch_agent"] + if isinstance(command_result, CommandBase): + command_payload: CommandPayload = command_result + kind = command_payload.kind + if kind == "switch_agent": + new_agent = command_payload.agent_name if new_agent in available_agents_set: agent = new_agent # Display new agent info immediately when switching @@ -153,14 +155,14 @@ async def prompt_loop( rich_print(f"[red]Agent '{new_agent}' not found[/red]") continue # Keep the existing list_prompts handler for backward compatibility - elif "list_prompts" in command_dict: + elif kind == "list_prompts": # Use the prompt_provider directly await self._list_prompts(prompt_provider, agent) continue - elif 
"select_prompt" in command_dict: + elif kind == "select_prompt": # Handle prompt selection, using both list_prompts and apply_prompt - prompt_name = command_dict.get("prompt_name") - prompt_index = command_dict.get("prompt_index") + prompt_name = command_payload.prompt_name + prompt_index = command_payload.prompt_index # If a specific index was provided (from /prompt ) if prompt_index is not None: @@ -190,28 +192,26 @@ async def prompt_loop( # Use the name-based selection await self._select_prompt(prompt_provider, agent, prompt_name) continue - elif "list_tools" in command_dict: + elif kind == "list_tools": # Handle tools list display await self._list_tools(prompt_provider, agent) continue - elif "list_skills" in command_dict: + elif kind == "list_skills": await self._list_skills(prompt_provider, agent) continue - elif "skills_command" in command_dict: - await self._handle_skills_command( - prompt_provider, agent, command_dict["skills_command"] - ) + elif kind == "skills_command": + payload = { + "action": command_payload.action, + "argument": command_payload.argument, + } + await self._handle_skills_command(prompt_provider, agent, payload) continue - elif "show_usage" in command_dict: + elif kind == "show_usage": # Handle usage display await self._show_usage(prompt_provider, agent) continue - elif "show_history" in command_dict: - history_info = command_dict.get("show_history") - history_agent = ( - history_info.get("agent") if isinstance(history_info, dict) else None - ) - target_agent = history_agent or agent + elif kind == "show_history": + target_agent = command_payload.agent or agent try: agent_obj = prompt_provider._agent(target_agent) except Exception: @@ -222,12 +222,8 @@ async def prompt_loop( usage = getattr(agent_obj, "usage_accumulator", None) display_history_overview(target_agent, history, usage) continue - elif "clear_last" in command_dict: - clear_info = command_dict.get("clear_last") - clear_agent = ( - clear_info.get("agent") if isinstance(clear_info, dict) else None - ) - target_agent = clear_agent or agent + elif kind == "clear_last": + target_agent = command_payload.agent or agent try: agent_obj = prompt_provider._agent(target_agent) except Exception: @@ -257,12 +253,8 @@ async def prompt_loop( f"[yellow]No messages to remove for agent '{target_agent}'.[/yellow]" ) continue - elif "clear_history" in command_dict: - clear_info = command_dict.get("clear_history") - clear_agent = ( - clear_info.get("agent") if isinstance(clear_info, dict) else None - ) - target_agent = clear_agent or agent + elif kind == "clear_history": + target_agent = command_payload.agent or agent try: agent_obj = prompt_provider._agent(target_agent) except Exception: @@ -284,21 +276,21 @@ async def prompt_loop( f"[yellow]Agent '{target_agent}' does not support clearing history.[/yellow]" ) continue - elif "show_system" in command_dict: + elif kind == "show_system": # Handle system prompt display await self._show_system(prompt_provider, agent) continue - elif "show_markdown" in command_dict: + elif kind == "show_markdown": # Handle markdown display await self._show_markdown(prompt_provider, agent) continue - elif "show_mcp_status" in command_dict: + elif kind == "show_mcp_status": rich_print() await show_mcp_status(agent, prompt_provider) continue - elif "save_history" in command_dict: + elif kind == "save_history": # Save history for the current agent - filename = command_dict.get("filename") + filename = command_payload.filename try: agent_obj = prompt_provider._agent(agent) @@ -314,13 +306,13 @@ 
async def prompt_loop( if result: rich_print(f"[green]{result}[/green]") continue - elif "load_history" in command_dict: + elif kind == "load_history": # Load history for the current agent - if command_dict.get("error"): - rich_print(f"[red]{command_dict['error']}[/red]") + if command_payload.error: + rich_print(f"[red]{command_payload.error}[/red]") continue - filename = command_dict.get("filename") + filename = command_payload.filename try: from fast_agent.mcp.prompts.prompt_load import load_history_into_agent @@ -328,16 +320,12 @@ async def prompt_loop( agent_obj = prompt_provider._agent(agent) # Load history directly without triggering LLM call - if hasattr(agent_obj, "rebuild_instruction_templates"): - await agent_obj.rebuild_instruction_templates() load_history_into_agent(agent_obj, Path(filename)) msg_count = len(agent_obj.message_history) rich_print( f"[green]Loaded {msg_count} messages from {filename}[/green]" ) - if hasattr(agent_obj, "rebuild_instruction_templates"): - await agent_obj.rebuild_instruction_templates() except FileNotFoundError: rich_print(f"[red]File not found: {filename}[/red]") except Exception as e: @@ -346,16 +334,19 @@ async def prompt_loop( # Skip further processing if: # 1. The command was handled (command_result is truthy) - # 2. The original input was a dictionary (special command like /prompt) - # 3. The command result itself is a dictionary (special command handling result) + # 2. The original input was a command payload (special command like /prompt) + # 3. The command result itself is a command payload (special command handling result) # This fixes the issue where /prompt without arguments gets sent to the LLM if ( command_result - or isinstance(user_input, dict) - or isinstance(command_result, dict) + or isinstance(user_input, CommandBase) + or isinstance(command_result, CommandBase) ): continue + if not isinstance(user_input, str): + continue + if user_input.upper() == "STOP": return result if user_input == "": @@ -1130,9 +1121,7 @@ async def _add_skill( repo_ref = getattr(marketplace[0], "repo_ref", None) repo_hint = f"{repo_url}@{repo_ref}" if repo_ref else repo_url if repo_hint: - rich_print( - f"[dim]Repository: {format_marketplace_display_url(repo_hint)}[/dim]" - ) + rich_print(f"[dim]Repository: {format_marketplace_display_url(repo_hint)}[/dim]") self._render_marketplace_skills(marketplace) selection = await get_selection_input( "Install skill by number or name (empty to cancel): ", @@ -1148,9 +1137,7 @@ async def _add_skill( return try: - install_path = await install_marketplace_skill( - skill, destination_root=manager_dir - ) + install_path = await install_marketplace_skill(skill, destination_root=manager_dir) except Exception as exc: # noqa: BLE001 rich_print(f"[red]Failed to install skill: {exc}[/red]") return @@ -1193,9 +1180,7 @@ async def _remove_skill( rich_print(f"[green]Removed skill:[/green] {manifest.name}") await self._refresh_agent_skills(prompt_provider, agent_name) - async def _refresh_agent_skills( - self, prompt_provider: "AgentApp", agent_name: str - ) -> None: + async def _refresh_agent_skills(self, prompt_provider: "AgentApp", agent_name: str) -> None: assert hasattr(prompt_provider, "_agent"), ( "Interactive prompt expects an AgentApp with _agent()" ) @@ -1206,9 +1191,7 @@ async def _refresh_agent_skills( ) instruction_context = None try: - skills_text = format_skills_for_prompt( - manifests, read_tool_name="read_skill" - ) + skills_text = format_skills_for_prompt(manifests, read_tool_name="read_skill") instruction_context 
= {"agentSkills": skills_text} except Exception: instruction_context = None @@ -1232,9 +1215,7 @@ def _render_marketplace_skills(self, marketplace: list[Any]) -> None: rich_print("") rich_print(f"[bold]{bundle_name}[/bold]") if bundle_description: - wrapped_lines = textwrap.wrap( - bundle_description.strip(), width=72 - ) + wrapped_lines = textwrap.wrap(bundle_description.strip(), width=72) for line in wrapped_lines: rich_print(f"[white]{line.strip()}[/white]") rich_print("") diff --git a/src/fast_agent/ui/streaming.py b/src/fast_agent/ui/streaming.py index 99defdf3a..b41141ec7 100644 --- a/src/fast_agent/ui/streaming.py +++ b/src/fast_agent/ui/streaming.py @@ -81,7 +81,9 @@ def __init__( self._convert_literal_newlines = False self._pending_literal_backslashes = "" initial_renderable = ( - Text("", style=self._plain_text_style) if self._use_plain_text else Markdown("") + Text("", style=self._plain_text_style or "") + if self._use_plain_text + else Markdown("") ) refresh_rate = ( PLAIN_STREAM_REFRESH_PER_SECOND diff --git a/src/fast_agent/ui/streaming_buffer.py b/src/fast_agent/ui/streaming_buffer.py index 28b1f997a..b417e3b83 100644 --- a/src/fast_agent/ui/streaming_buffer.py +++ b/src/fast_agent/ui/streaming_buffer.py @@ -249,21 +249,23 @@ def _find_tables(self, text: str) -> list[Table]: tables = [] for i, token in enumerate(tokens): - if token.type == "table_open" and token.map: + token_map = token.map + if token.type == "table_open" and token_map is not None: # Find tbody within this table to extract header tbody_start_line = None # Look ahead for tbody for j in range(i + 1, len(tokens)): - if tokens[j].type == "tbody_open" and tokens[j].map: - tbody_start_line = tokens[j].map[0] + tbody_map = tokens[j].map + if tokens[j].type == "tbody_open" and tbody_map is not None: + tbody_start_line = tbody_map[0] break elif tokens[j].type == "table_close": break if tbody_start_line is not None: - table_start_line = token.map[0] - table_end_line = token.map[1] + table_start_line = token_map[0] + table_end_line = token_map[1] # Calculate positions start_pos = sum(len(line) + 1 for line in lines[:table_start_line]) diff --git a/tests/unit/fast_agent/core/test_instruction_refresh.py b/tests/unit/fast_agent/core/test_instruction_refresh.py index df0a172a8..f7385a329 100644 --- a/tests/unit/fast_agent/core/test_instruction_refresh.py +++ b/tests/unit/fast_agent/core/test_instruction_refresh.py @@ -1,54 +1,201 @@ +"""Tests for instruction building and refresh utilities.""" + import asyncio -from fast_agent.core.instruction_refresh import rebuild_agent_instruction +from fast_agent.core.instruction_refresh import ( + McpInstructionCapable, + build_instruction, + format_server_instructions, + rebuild_agent_instruction, +) + + +class StubAggregator: + """Stub aggregator that returns predefined server instructions.""" + + def __init__(self, instructions: dict[str, tuple[str | None, list[str]]] | None = None): + self._instructions = instructions or {} + + async def get_server_instructions(self) -> dict[str, tuple[str | None, list[str]]]: + return self._instructions class StubAgent: - def __init__(self) -> None: - self.skill_registry = None - self.manifests = None - self.context = None - self.instruction_context = None - self.rebuild_calls = 0 + """Stub that implements McpInstructionCapable for testing.""" + + def __init__( + self, + template: str = "Test instruction", + aggregator: StubAggregator | None = None, + ) -> None: + self._instruction = template + self._instruction_template = template + 
self._instruction_context: dict[str, str] = {} + self._skill_manifests: list = [] + self._skill_registry = None + self._aggregator = aggregator or StubAggregator() + self._has_filesystem_runtime = False + + @property + def instruction(self) -> str: + return self._instruction + + def set_instruction(self, instruction: str) -> None: + self._instruction = instruction + + @property + def instruction_template(self) -> str: + return self._instruction_template + + @property + def instruction_context(self) -> dict[str, str]: + return self._instruction_context + + @property + def aggregator(self): + return self._aggregator + + @property + def skill_manifests(self) -> list: + return self._skill_manifests + + @property + def skill_registry(self): + return self._skill_registry + + @skill_registry.setter + def skill_registry(self, value): + self._skill_registry = value def set_skill_manifests(self, manifests) -> None: - self.manifests = list(manifests) + self._skill_manifests = list(manifests) + + def set_instruction_context(self, context: dict[str, str]) -> None: + self._instruction_context.update(context) + + @property + def has_filesystem_runtime(self) -> bool: + return self._has_filesystem_runtime + + +# Ensure StubAgent is recognized as McpInstructionCapable +assert isinstance(StubAgent(), McpInstructionCapable) + + +# ───────────────────────────────────────────────────────────────────────────── +# Test format_server_instructions +# ───────────────────────────────────────────────────────────────────────────── + + +def test_format_server_instructions_empty() -> None: + result = format_server_instructions({}) + assert result == "" + + +def test_format_server_instructions_with_data() -> None: + data = { + "test-server": ("Do helpful things", ["tool1", "tool2"]), + } + result = format_server_instructions(data) + assert "test-server" in result + assert "Do helpful things" in result + assert "test-server__tool1" in result + assert "test-server__tool2" in result + + +def test_format_server_instructions_skips_none() -> None: + data = { + "server1": ("Instructions", ["tool1"]), + "server2": (None, ["tool2"]), # Should be skipped + } + result = format_server_instructions(data) + assert "server1" in result + assert "server2" not in result + + +# ───────────────────────────────────────────────────────────────────────────── +# Test build_instruction +# ───────────────────────────────────────────────────────────────────────────── - def set_instruction_context(self, context) -> None: - self.instruction_context = dict(context) - async def rebuild_instruction_templates(self) -> None: - await asyncio.sleep(0) - self.rebuild_calls += 1 +def test_build_instruction_resolves_builtins() -> None: + template = "Today is {{currentDate}}. 
Platform: {{hostPlatform}}" + result = asyncio.run(build_instruction(template)) + # Should not contain the placeholders anymore + assert "{{currentDate}}" not in result + assert "{{hostPlatform}}" not in result + + +def test_build_instruction_with_context() -> None: + template = "Root: {{workspaceRoot}}" + result = asyncio.run(build_instruction(template, context={"workspaceRoot": "/test/path"})) + assert result == "Root: /test/path" + + +def test_build_instruction_with_aggregator() -> None: + template = "{{serverInstructions}}" + aggregator = StubAggregator({"my-server": ("Be helpful", ["do_thing"])}) + result = asyncio.run(build_instruction(template, aggregator=aggregator)) + assert "my-server" in result + assert "Be helpful" in result + + +# ───────────────────────────────────────────────────────────────────────────── +# Test rebuild_agent_instruction +# ───────────────────────────────────────────────────────────────────────────── def test_rebuild_agent_instruction_updates_fields() -> None: + agent = StubAgent(template="Hello {{workspaceRoot}}") + result = asyncio.run( + rebuild_agent_instruction( + agent, + context={"workspaceRoot": "/my/path"}, + ) + ) + assert agent.instruction == "Hello /my/path" + assert agent.instruction_context == {"workspaceRoot": "/my/path"} + assert result.updated_context is True + assert result.rebuilt_instruction is True + + +def test_rebuild_agent_instruction_updates_skill_manifests() -> None: agent = StubAgent() result = asyncio.run( rebuild_agent_instruction( agent, - skill_manifests=[object()], - instruction_context={"agentSkills": "skills"}, - skill_registry="registry", + skill_manifests=["manifest1", "manifest2"], ) ) - assert agent.manifests is not None - assert agent.instruction_context == {"agentSkills": "skills"} - assert agent.skill_registry == "registry" - assert agent.rebuild_calls == 1 + assert agent.skill_manifests == ["manifest1", "manifest2"] assert result.updated_skill_manifests is True - assert result.updated_instruction_context is True + + +def test_rebuild_agent_instruction_updates_skill_registry() -> None: + agent = StubAgent() + result = asyncio.run( + rebuild_agent_instruction( + agent, + skill_registry="my-registry", + ) + ) + assert agent.skill_registry == "my-registry" assert result.updated_skill_registry is True - assert result.rebuilt_instruction is True -def test_rebuild_agent_instruction_handles_missing_methods() -> None: +def test_rebuild_agent_instruction_handles_non_mcp_agent() -> None: class MinimalAgent: pass agent = MinimalAgent() result = asyncio.run(rebuild_agent_instruction(agent)) assert result.updated_skill_manifests is False - assert result.updated_instruction_context is False + assert result.updated_context is False assert result.updated_skill_registry is False assert result.rebuilt_instruction is False + + +def test_rebuild_agent_instruction_handles_empty_template() -> None: + agent = StubAgent(template="") + result = asyncio.run(rebuild_agent_instruction(agent)) + assert result.rebuilt_instruction is False diff --git a/typesafe.md b/typesafe.md new file mode 100644 index 000000000..7cfca9c6d --- /dev/null +++ b/typesafe.md @@ -0,0 +1,119 @@ +# Type Safety Plan (ty) + +This document describes the rules and approach we will use to make the codebase type-safe with `ty`. +It uses a small set of examples from the `ty` docs to anchor our conventions, then lays out a +reproducible plan for getting to a clean `ty check`. 
+ +Sources (examples below are based on these pages): +- [ty rules](https://docs.astral.sh/ty/rules/) +- [ty suppression](https://docs.astral.sh/ty/suppression/) +- [ty configuration](https://docs.astral.sh/ty/configuration/) + +## Examples From ty (reference patterns) + +### Rule levels via CLI and config +Use rule-level settings to gradually tighten checks: + +```shell +ty check \ + --warn unused-ignore-comment \ + --ignore redundant-cast \ + --error possibly-missing-attribute \ + --error possibly-missing-import +``` + +Equivalent `pyproject.toml`: + +```toml +[tool.ty.rules] +unused-ignore-comment = "warn" +redundant-cast = "ignore" +possibly-missing-attribute = "error" +possibly-missing-import = "error" +``` + +### Targeted suppressions +Prefer narrow, rule-specific suppressions: + +```py +sum_three_numbers("one", 5) # ty: ignore[missing-argument, invalid-argument-type] +``` + +Multi-line suppression can go on the first or last line of the violation: + +```py +sum_three_numbers( # ty: ignore[missing-argument] + 3, + 2 +) +``` + +### Whole-function suppression +Use `@no_type_check` only when a function is intentionally dynamic: + +```py +from typing import no_type_check + +@no_type_check +def main(): + sum_three_numbers(1, 2) # no error for the missing argument +``` + +## Modern, Pythonic Typing Rules + +These rules are what we will follow as we make the codebase type-safe. They are aligned with +Python 3.13+ and current typing guidance. + +- Use builtin generics and PEP 604 unions: `list[str]`, `dict[str, int]`, `X | None`. +- For unused features or parameters, check call sites/usage before removal and confirm with the + user before deleting behavior, even if it appears unused. +- For command parsing, prefer a discriminated `TypedDict` union (e.g., `kind` field) over ad-hoc + nested dicts. If the command surface is shared across modules or grows over time, keep the + payload types in a small dedicated module; otherwise colocate them with the parser. +- For dynamic/optional collection attributes, narrow to `collections.abc.Collection` or `Sized` + when you only need `len()` or membership; avoid materializing unless multiple passes or sorting + is required. Exclude `str`/`bytes` when treating a value as a general collection. +- Use `getattr` only for truly dynamic attributes (plugin/duck-typed objects); immediately narrow + the result with `isinstance`/helper guards and avoid masking real missing attributes. +- For core agent types with concrete classes, prefer `isinstance` narrowing against those classes + over `getattr`/duck-typing to keep behavior explicit and enforceable by the type checker. +- Avoid `hasattr` checks when accessing attributes that all implementations share. Instead, add + the attribute to the Protocol so the type checker can verify access. This eliminates dynamic + lookups and makes the interface explicit. +- When accepting an `Iterable` but needing multiple passes or sorting, materialize to `list` + once at the boundary to avoid exhausting generators. +- Annotate public APIs and module boundaries first (CLI entry points, FastAPI routes, shared utils). +- Avoid `Any` unless crossing untyped boundaries; when unavoidable, localize it and add a comment. +- Prefer `TypedDict` or `Protocol` over loose `dict[str, object]` and `Any` for structured data. +- Use `Literal` or `Enum` for fixed choices; use `Final` for constants. 
+- Prefer `collections.abc` types for inputs (`Sequence`, `Mapping`, `Iterable`) and concrete types
+  for outputs (e.g., `list`, `dict`) when callers rely on mutability.
+- If a capability is optional, pick a single shape (property or method) and document it in the
+  protocol; avoid supporting both unless required by existing implementations.
+- For pydantic models, prefer explicit field types and `Annotated[...]` where validation metadata
+  is needed.
+- Use `Self` for fluent APIs and `TypeAlias` for complex, reused types.
+- Use `type: ignore` only when interacting with third-party APIs that are untyped or known-broken;
+  otherwise prefer `ty: ignore[rule]` with the specific rule.
+
+## Reproducible Plan
+
+1. **Baseline**: run `ty check` on `src/fast_agent` and capture the initial error set.
+2. **Triage**: group issues by module and rule; fix the highest-signal errors first.
+3. **Configure**: set rule levels in `pyproject.toml` so low-signal rules are `warn` while we
+   converge; keep `possibly-missing-attribute` and `possibly-missing-import` at `error`.
+4. **Annotate**: add types to public APIs, then internal helpers, then tests.
+5. **Refine**: replace broad `Any` or `object` with `TypedDict`, `Protocol`, `Literal`, or
+   `Enum` as appropriate.
+6. **Suppress sparingly**: use `# ty: ignore[rule]` only when the type system cannot express a
+   valid pattern; include a short reason.
+7. **Enforce**: add `ty check` to CI once warnings are near-zero; tighten rules to `error` as
+   we converge.
+
+## Decision Log (initial)
+
+- We will use `ty: ignore[rule]` over bare `ty: ignore` and avoid `type: ignore` unless an external
+  dependency forces it.
+- We will prefer modern syntax (`X | Y`, builtin generics) given Python 3.13+.
+- We will represent parsed UI commands as a lightweight discriminated union (with a `kind`
+  field) in a dedicated module, while leaving free-form user input as plain `str`. The shipped
+  `command_payloads.py` realises this union as frozen dataclasses with `Literal` `kind` fields
+  rather than as a `TypedDict` union.
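+
+## Worked Examples (sketches)
+
+A minimal sketch of the discriminated command-payload pattern in its `TypedDict`
+flavour, assuming Python 3.13; the payload names and fields here are illustrative,
+not the real `command_payloads.py` API:
+
+```py
+from typing import Literal, TypedDict
+
+
+class SelectPrompt(TypedDict):
+    kind: Literal["select_prompt"]
+    prompt_name: str | None
+
+
+class SwitchAgent(TypedDict):
+    kind: Literal["switch_agent"]
+    agent_name: str
+
+
+Command = SelectPrompt | SwitchAgent
+
+
+def describe(command: Command) -> str:
+    # Comparing the Literal "kind" key narrows the union, so each branch
+    # sees only the fields defined on that member.
+    if command["kind"] == "switch_agent":
+        return f"switch to {command['agent_name']}"
+    return f"select {command['prompt_name'] or 'a prompt interactively'}"
+```
+
+And a sketch of the getattr-then-narrow rule; `get_hf_display_info` is the dynamic
+attribute the UI already probes, while the wrapper itself is hypothetical:
+
+```py
+def hf_display_info(llm: object) -> dict | None:
+    # Dynamic lookup for an optional capability: fetch, then narrow before
+    # use so a missing or mistyped attribute cannot leak through as Any.
+    getter = getattr(llm, "get_hf_display_info", None)
+    if not callable(getter):
+        return None
+    info = getter()
+    return info if isinstance(info, dict) else None
+```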
From 2f3d443a943579bdac6f9be5f324396c3bf74b6e Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Tue, 23 Dec 2025 10:29:51 +0000 Subject: [PATCH 02/15] imports --- src/fast_agent/acp/server/agent_acp_server.py | 58 +++++++++---------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/src/fast_agent/acp/server/agent_acp_server.py b/src/fast_agent/acp/server/agent_acp_server.py index bf60c180b..6012b97f1 100644 --- a/src/fast_agent/acp/server/agent_acp_server.py +++ b/src/fast_agent/acp/server/agent_acp_server.py @@ -50,6 +50,13 @@ from fast_agent.acp.content_conversion import convert_acp_prompt_to_mcp_content_blocks from fast_agent.acp.filesystem_runtime import ACPFilesystemRuntime from fast_agent.acp.permission_store import PermissionStore +from fast_agent.acp.protocols import ( + FilesystemRuntimeCapable, + InstructionContextCapable, + PlanTelemetryCapable, + ShellRuntimeCapable, + WorkflowTelemetryCapable, +) from fast_agent.acp.slash_commands import SlashCommandHandler from fast_agent.acp.terminal_runtime import ACPTerminalRuntime from fast_agent.acp.tool_permission_adapter import ACPToolPermissionAdapter @@ -72,6 +79,7 @@ from fast_agent.llm.stream_types import StreamChunk from fast_agent.llm.usage_tracking import last_turn_usage from fast_agent.mcp.helpers.content_helpers import is_text_content +from fast_agent.mcp.types import McpAgentProtocol from fast_agent.types import LlmStopReason, PromptMessageExtended, RequestParams from fast_agent.workflow_telemetry import ACPPlanTelemetryProvider, ToolHandlerWorkflowTelemetry @@ -448,15 +456,8 @@ def _build_session_modes( # Create a SessionMode for each agent for agent_name, agent in instance.agents.items(): - # Get instruction from agent's config - instruction = "" - resolved_instruction = resolved_cache.get(agent_name) - if resolved_instruction: - instruction = resolved_instruction - elif hasattr(agent, "_config") and hasattr(agent._config, "instruction"): - instruction = agent._config.instruction - elif hasattr(agent, "instruction"): - instruction = agent.instruction + # Get instruction from resolved cache (if available) or agent's instruction + instruction = resolved_cache.get(agent_name) or agent.instruction # Format description (first line, truncated to 200 chars) description = truncate_description(instruction) if instruction else None @@ -606,8 +607,8 @@ async def new_session( # Register tool handler with agents' aggregators for agent_name, agent in instance.agents.items(): - if hasattr(agent, "_aggregator"): - aggregator = agent._aggregator + if isinstance(agent, McpAgentProtocol): + aggregator = agent.aggregator aggregator._tool_handler = tool_handler logger.info( @@ -617,11 +618,11 @@ async def new_session( agent_name=agent_name, ) - if hasattr(agent, "workflow_telemetry"): + if isinstance(agent, WorkflowTelemetryCapable): agent.workflow_telemetry = workflow_telemetry # Set up plan telemetry for agents that support it (e.g., IterativePlanner) - if hasattr(agent, "plan_telemetry"): + if isinstance(agent, PlanTelemetryCapable): plan_telemetry = ACPPlanTelemetryProvider(self._connection, session_id) agent.plan_telemetry = plan_telemetry logger.info( @@ -667,8 +668,8 @@ async def new_session( # Register permission handler with all agents' aggregators for agent_name, agent in instance.agents.items(): - if hasattr(agent, "_aggregator"): - aggregator = agent._aggregator + if isinstance(agent, McpAgentProtocol): + aggregator = agent.aggregator aggregator._permission_handler = 
permission_handler logger.info( @@ -691,7 +692,7 @@ async def new_session( # Check if any agent has shell runtime enabled for agent_name, agent in instance.agents.items(): if ( - hasattr(agent, "_shell_runtime_enabled") + isinstance(agent, ShellRuntimeCapable) and agent._shell_runtime_enabled ): # Create ACPTerminalRuntime for this session @@ -703,7 +704,7 @@ async def new_session( session_id=session_id, activation_reason="via ACP terminal support", timeout_seconds=getattr( - agent._shell_runtime, "timeout_seconds", 90 + agent._shell_runtime, "timeout_seconds", 90 # ty: ignore[unresolved-attribute] ), tool_handler=tool_handler, default_output_byte_limit=default_limit, @@ -711,17 +712,16 @@ async def new_session( ) # Inject into agent - if hasattr(agent, "set_external_runtime"): - agent.set_external_runtime(terminal_runtime) - session_state.terminal_runtime = terminal_runtime + agent.set_external_runtime(terminal_runtime) + session_state.terminal_runtime = terminal_runtime - logger.info( - "ACP terminal runtime injected", - name="acp_terminal_injected", - session_id=session_id, - agent_name=agent_name, - default_output_limit=default_limit, - ) + logger.info( + "ACP terminal runtime injected", + name="acp_terminal_injected", + session_id=session_id, + agent_name=agent_name, + default_output_limit=default_limit, + ) # If client supports filesystem operations, inject ACP filesystem runtime if self._client_supports_fs_read or self._client_supports_fs_write: @@ -741,7 +741,7 @@ async def new_session( # Inject filesystem runtime into each agent for agent_name, agent in instance.agents.items(): - if hasattr(agent, "set_filesystem_runtime"): + if isinstance(agent, FilesystemRuntimeCapable): agent.set_filesystem_runtime(filesystem_runtime) logger.info( "ACP filesystem runtime injected", @@ -774,7 +774,7 @@ async def new_session( # Set session context on agents that have InstructionBuilder # This ensures {{env}}, {{workspaceRoot}}, etc. 
are available when rebuilding for agent_name, agent in instance.agents.items(): - if hasattr(agent, "set_instruction_context"): + if isinstance(agent, InstructionContextCapable): try: agent.set_instruction_context(session_context) except Exception as e: From 4eb14bb2a04e296b85deff3259b9e14b50b25473 Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Tue, 23 Dec 2025 11:18:38 +0000 Subject: [PATCH 03/15] typesafe --- src/fast_agent/core/direct_decorators.py | 14 ++- src/fast_agent/core/fastagent.py | 109 +++++++++++++++-------- 2 files changed, 85 insertions(+), 38 deletions(-) diff --git a/src/fast_agent/core/direct_decorators.py b/src/fast_agent/core/direct_decorators.py index 96dd7b6af..ccb4e9cd8 100644 --- a/src/fast_agent/core/direct_decorators.py +++ b/src/fast_agent/core/direct_decorators.py @@ -275,7 +275,12 @@ def agent( tools: dict[str, list[str]] | None = None, resources: dict[str, list[str]] | None = None, prompts: dict[str, list[str]] | None = None, - skills: SkillManifest | SkillRegistry | Path | str | None = None, + skills: SkillManifest + | SkillRegistry + | Path + | str + | list[SkillManifest | SkillRegistry | Path | str | None] + | None = None, model: str | None = None, use_history: bool = True, request_params: RequestParams | None = None, @@ -353,7 +358,12 @@ def custom( tools: dict[str, list[str]] | None = None, resources: dict[str, list[str]] | None = None, prompts: dict[str, list[str]] | None = None, - skills: SkillManifest | SkillRegistry | Path | str | None = None, + skills: SkillManifest + | SkillRegistry + | Path + | str + | list[SkillManifest | SkillRegistry | Path | str | None] + | None = None, model: str | None = None, use_history: bool = True, request_params: RequestParams | None = None, diff --git a/src/fast_agent/core/fastagent.py b/src/fast_agent/core/fastagent.py index 8a997b978..5bd919855 100644 --- a/src/fast_agent/core/fastagent.py +++ b/src/fast_agent/core/fastagent.py @@ -19,15 +19,16 @@ TYPE_CHECKING, Any, AsyncIterator, - Awaitable, Callable, Literal, ParamSpec, Sequence, + TypeAlias, TypeVar, ) import yaml +import yaml.parser from opentelemetry import trace from fast_agent import config @@ -101,6 +102,7 @@ F = TypeVar("F", bound=Callable[..., Any]) # For decorated functions logger = get_logger(__name__) +SkillEntry: TypeAlias = SkillManifest | SkillRegistry | Path | str class FastAgent: @@ -368,7 +370,12 @@ def agent( tools: dict[str, list[str]] | None = None, resources: dict[str, list[str]] | None = None, prompts: dict[str, list[str]] | None = None, - skills: list[SkillManifest | SkillRegistry | Path | str | None] | None = None, + skills: SkillManifest + | SkillRegistry + | Path + | str + | list[SkillManifest | SkillRegistry | Path | str | None] + | None = None, model: str | None = None, use_history: bool = True, request_params: RequestParams | None = None, @@ -390,11 +397,17 @@ def custom( name: str = "default", instruction_or_kwarg: str | Path | AnyUrl | None = None, *, - instruction: str | Path | AnyUrl = DEFAULT_AGENT_INSTRUCTION, + instruction: str | Path | AnyUrl = "You are a helpful agent.", servers: list[str] = [], tools: dict[str, list[str]] | None = None, resources: dict[str, list[str]] | None = None, prompts: dict[str, list[str]] | None = None, + skills: SkillManifest + | SkillRegistry + | Path + | str + | list[SkillManifest | SkillRegistry | Path | str | None] + | None = None, model: str | None = None, use_history: bool = True, request_params: RequestParams | None = None, @@ -402,7 +415,9 @@ def 
custom( default: bool = False, elicitation_handler: ElicitationFnT | None = None, api_key: str | None = None, - ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ... + ) -> Callable[ + [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]] + ]: ... def orchestrator( self, @@ -420,7 +435,9 @@ def orchestrator( plan_iterations: int = 5, default: bool = False, api_key: str | None = None, - ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ... + ) -> Callable[ + [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]] + ]: ... def iterative_planner( self, @@ -433,7 +450,9 @@ def iterative_planner( plan_iterations: int = -1, default: bool = False, api_key: str | None = None, - ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ... + ) -> Callable[ + [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]] + ]: ... def router( self, @@ -452,7 +471,9 @@ def router( default: bool = False, elicitation_handler: ElicitationFnT | None = None, api_key: str | None = None, - ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ... + ) -> Callable[ + [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]] + ]: ... def chain( self, @@ -462,7 +483,9 @@ def chain( instruction: str | Path | AnyUrl | None = None, cumulative: bool = False, default: bool = False, - ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ... + ) -> Callable[ + [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]] + ]: ... def parallel( self, @@ -473,7 +496,9 @@ def parallel( instruction: str | Path | AnyUrl | None = None, include_request: bool = True, default: bool = False, - ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ... + ) -> Callable[ + [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]] + ]: ... def evaluator_optimizer( self, @@ -486,7 +511,9 @@ def evaluator_optimizer( max_refinements: int = 3, refinement_instruction: str | None = None, default: bool = False, - ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ... + ) -> Callable[ + [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]] + ]: ... def maker( self, @@ -499,18 +526,21 @@ def maker( red_flag_max_length: int | None = None, instruction: str | Path | AnyUrl | None = None, default: bool = False, - ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ... + ) -> Callable[ + [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]] + ]: ... 
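+
+    # NOTE: the decorator signatures above are typing-only stubs; at runtime the
+    # names are immediately rebound to the real implementations below, and the
+    # `if not TYPE_CHECKING` guard keeps the checker bound to these precise
+    # signatures instead of the generic decorator callables.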
# Runtime bindings (actual implementations) - agent = agent_decorator - custom = custom_decorator - orchestrator = orchestrator_decorator - iterative_planner = orchestrator2_decorator - router = router_decorator - chain = chain_decorator - parallel = parallel_decorator - evaluator_optimizer = evaluator_optimizer_decorator - maker = maker_decorator + if not TYPE_CHECKING: + agent = agent_decorator + custom = custom_decorator + orchestrator = orchestrator_decorator + iterative_planner = orchestrator2_decorator + router = router_decorator + chain = chain_decorator + parallel = parallel_decorator + evaluator_optimizer = evaluator_optimizer_decorator + maker = maker_decorator def _get_acp_server_class(self): """Import and return the ACP server class with helpful error handling.""" @@ -548,12 +578,13 @@ async def run(self) -> AsyncIterator["AgentApp"]: cli_model_override = getattr(self.args, "model", None) # Store the model source for UI display + config = self.context.config model_source = get_default_model_source( - config_default_model=self.context.config.default_model, + config_default_model=config.default_model if config else None, cli_model=cli_model_override, ) - if self.context.config: - self.context.config.model_source = model_source # type: ignore[attr-defined] + if config: + config.model_source = model_source # type: ignore[attr-defined] tracer = trace.get_tracer(__name__) with tracer.start_as_current_span(self.name): @@ -892,7 +923,7 @@ def _apply_instruction_context( continue resolved = apply_template_variables(template, context_vars) - if resolved == template: + if resolved is None or resolved == template: continue # Use set_instruction() which handles syncing request_params and LLM @@ -919,27 +950,33 @@ def _apply_skills_to_agent_configs(self, default_skills: list[SkillManifest]) -> def _resolve_skills( self, - entry: SkillManifest - | SkillRegistry - | Path - | str - | list[SkillManifest | SkillRegistry | Path | str | None] - | None, + entry: SkillEntry | list[SkillEntry | None] | None, ) -> list[SkillManifest]: if entry is None: return [] if isinstance(entry, list): - filtered = [item for item in entry if item is not None] + filtered: list[SkillEntry] = [] + for item in entry: + if isinstance(item, (SkillManifest, SkillRegistry, Path, str)): + filtered.append(item) + elif item is not None: + logger.debug( + "Unsupported skill entry type", + data={"type": type(item).__name__}, + ) if not filtered: return [] - if all(isinstance(item, (Path, str)) for item in filtered): - directories = [ - Path(item) if isinstance(item, str) else item for item in filtered - ] + directory_entries = [ + item for item in filtered if isinstance(item, (Path, str)) + ] + if len(directory_entries) == len(filtered): + directories: list[Path | str] = [] + for item in directory_entries: + directories.append(Path(item) if isinstance(item, str) else item) registry = SkillRegistry(base_dir=Path.cwd(), directories=directories) return registry.load_manifests() manifests: list[SkillManifest] = [] - for item in entry: + for item in filtered: manifests.extend(self._resolve_skills(item)) return manifests if isinstance(entry, SkillManifest): From 4d46d41b862e1f87bed43433b5f52f7cf36b64bd Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Tue, 23 Dec 2025 11:40:35 +0000 Subject: [PATCH 04/15] refactor bedrock, pyproject --- pyproject.toml | 11 +-- .../llm/provider/bedrock/llm_bedrock.py | 98 ++++++++++--------- 2 files changed, 53 insertions(+), 56 deletions(-) diff --git 
a/pyproject.toml b/pyproject.toml index fcfd541a1..05878383f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,16 +64,6 @@ all-providers = [ "boto3>=1.35.0", "tensorzero>=2025.7.5" ] -dev = [ - "pre-commit>=4.0.1", - "pydantic>=2.10.4", - "pyyaml>=6.0.2", - "ruff>=0.8.4", - "pytest>=7.4.0", - "pytest-asyncio>=0.21.1", - "pytest-cov", - "ruamel.yaml>=0.18.0", -] [build-system] requires = ["hatchling"] @@ -117,6 +107,7 @@ testpaths = ["tests"] [dependency-groups] dev = [ + "boto3>=1.35.0", "pre-commit>=4.0.1", "pydantic>=2.10.4", "ruamel.yaml>=0.18.0", diff --git a/src/fast_agent/llm/provider/bedrock/llm_bedrock.py b/src/fast_agent/llm/provider/bedrock/llm_bedrock.py index b0194cd4f..ab4b5749d 100644 --- a/src/fast_agent/llm/provider/bedrock/llm_bedrock.py +++ b/src/fast_agent/llm/provider/bedrock/llm_bedrock.py @@ -36,8 +36,12 @@ from mcp import ListToolsResult try: - import boto3 - from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError + import boto3 # ty: ignore[unresolved-import] + from botocore.exceptions import ( # ty: ignore[unresolved-import] + BotoCoreError, + ClientError, + NoCredentialsError, + ) except ImportError: boto3 = None BotoCoreError = Exception @@ -197,8 +201,11 @@ def __init__(self, *args, **kwargs) -> None: # Extract AWS configuration from kwargs first self.aws_region = kwargs.pop("region", None) self.aws_profile = kwargs.pop("profile", None) + kwargs.pop("provider", None) + if args and isinstance(args[0], Provider): + args = args[1:] - super().__init__(*args, provider=Provider.BEDROCK, **kwargs) + super().__init__(Provider.BEDROCK, *args, **kwargs) # Use config values if not provided in kwargs (after super().__init__) if self.context.config and self.context.config.bedrock: @@ -250,7 +257,7 @@ def _initialize_default_params(self, kwargs: dict) -> RequestParams: @property def model(self) -> str: """Get the model name, guaranteed to be set.""" - return self.default_request_params.model + return self.default_request_params.model or DEFAULT_BEDROCK_MODEL def _get_bedrock_client(self): """Get or create Bedrock client.""" @@ -890,7 +897,8 @@ def _convert_multipart_to_bedrock_message( Returns: Bedrock message parameter dictionary """ - bedrock_msg = {"role": msg.role, "content": []} + content_blocks: list[dict[str, Any]] = [] + bedrock_msg = {"role": msg.role, "content": content_blocks} # Handle tool results first (if present) if msg.tool_results: @@ -920,7 +928,7 @@ def _convert_multipart_to_bedrock_message( if tool_result_parts: full_result_text = f"Tool Results:\n{', '.join(tool_result_parts)}" - bedrock_msg["content"].append({"type": "text", "text": full_result_text}) + content_blocks.append({"type": "text", "text": full_result_text}) else: # For Nova/Anthropic models: use structured tool_result format for tool_id, tool_result in msg.tool_results.items(): @@ -933,7 +941,7 @@ def _convert_multipart_to_bedrock_message( if not result_content_blocks: result_content_blocks.append({"text": "[No content in tool result]"}) - bedrock_msg["content"].append( + content_blocks.append( { "type": "tool_result", "tool_use_id": tool_id, @@ -945,7 +953,7 @@ def _convert_multipart_to_bedrock_message( # Handle regular content for content_item in msg.content: if isinstance(content_item, TextContent): - bedrock_msg["content"].append({"type": "text", "text": content_item.text}) + content_blocks.append({"type": "text", "text": content_item.text}) return bedrock_msg @@ -1134,8 +1142,11 @@ async def _process_stream( # Construct the response message full_text = 
"".join(response_content) + response_content_items: list[dict[str, Any]] = ( + [{"text": full_text}] if full_text else [] + ) response = { - "content": [{"text": full_text}] if full_text else [], + "content": response_content_items, "stop_reason": stop_reason or "end_turn", "usage": { "input_tokens": usage.get("inputTokens", 0), @@ -1174,7 +1185,7 @@ async def _process_stream( # Clean up the accumulator del tool_use["toolUse"]["_input_accumulator"] - response["content"].extend(tool_uses) + response_content_items.extend(tool_uses) return response @@ -1249,30 +1260,16 @@ async def _bedrock_completion( else: messages.append(message_param) - # Get available tools (no resolver gating; fallback logic will decide wiring) + # Tools are provided by the caller (aligned with other providers) tool_list = None - - try: - tool_list = await self.aggregator.list_tools() - self.logger.debug(f"Found {len(tool_list.tools)} MCP tools") - except Exception as e: - self.logger.error(f"Error fetching MCP tools: {e}") - import traceback - - self.logger.debug(f"Traceback: {traceback.format_exc()}") - tool_list = None - - # Use tools parameter if provided, otherwise get from aggregator - if tools is None: - tools = tool_list.tools if tool_list else [] - elif tool_list is None and tools: + if tools: # Create a ListToolsResult from the provided tools for conversion from mcp.types import ListToolsResult tool_list = ListToolsResult(tools=tools) response_content_blocks: list[ContentBlock] = [] - model = self.default_request_params.model + model = self.default_request_params.model or DEFAULT_BEDROCK_MODEL # Single API call - no tool execution loop self._log_chat_progress(self.chat_turn(), model=model) @@ -1322,12 +1319,15 @@ async def _bedrock_completion( # Track whether we changed system mode cache this turn tried_system_fallback = False - processed_response = None # type: ignore[assignment] + processed_response: dict[str, Any] | None = None last_error_msg = None for schema_choice in schema_order: # Fresh messages per attempt - converse_args = {"modelId": model, "messages": [dict(m) for m in bedrock_messages]} + converse_args: dict[str, Any] = { + "modelId": model, + "messages": [dict(m) for m in bedrock_messages], + } # Build tools representation for this schema tools_payload: Union[list[dict[str, Any]], str, None] = None @@ -1586,9 +1586,11 @@ async def _bedrock_completion( # Apply temperature now that reasoning is disabled if params.temperature is not None: - if "inferenceConfig" not in converse_args: - converse_args["inferenceConfig"] = {} - converse_args["inferenceConfig"]["temperature"] = params.temperature + retry_inference_config = converse_args.get("inferenceConfig") + if not isinstance(retry_inference_config, dict): + retry_inference_config = {} + converse_args["inferenceConfig"] = retry_inference_config + retry_inference_config["temperature"] = params.temperature # Retry the API call if not use_streaming: @@ -1674,7 +1676,7 @@ async def _bedrock_completion( # Retry the same schema immediately in inject mode try: # Rebuild messages for inject - converse_args = { + converse_args: dict[str, Any] = { "modelId": model, "messages": [dict(m) for m in bedrock_messages], } @@ -1750,15 +1752,16 @@ async def _bedrock_completion( } # Track usage - if processed_response.get("usage"): + usage = processed_response.get("usage") if processed_response else None + if isinstance(usage, dict): try: - usage = processed_response["usage"] turn_usage = TurnUsage( - provider=Provider.BEDROCK.value, + provider=Provider.BEDROCK, 
model=model, - input_tokens=usage.get("input_tokens", 0), - output_tokens=usage.get("output_tokens", 0), - total_tokens=usage.get("input_tokens", 0) + usage.get("output_tokens", 0), + input_tokens=int(usage.get("input_tokens", 0) or 0), + output_tokens=int(usage.get("output_tokens", 0) or 0), + total_tokens=int(usage.get("input_tokens", 0) or 0) + + int(usage.get("output_tokens", 0) or 0), raw_usage=usage, ) self.usage_accumulator.add_turn(turn_usage) @@ -1772,9 +1775,10 @@ async def _bedrock_completion( messages.append(response_message_param) # Extract text content for responses - if processed_response.get("content"): - for content_item in processed_response["content"]: - if content_item.get("text"): + content_items = processed_response.get("content") if processed_response else None + if isinstance(content_items, list): + for content_item in content_items: + if isinstance(content_item, dict) and content_item.get("text"): response_content_blocks.append( TextContent(type="text", text=content_item["text"]) ) @@ -1820,7 +1824,8 @@ async def _bedrock_completion( pass # Handle different stop reasons - stop_reason = processed_response.get("stop_reason", "end_turn") + stop_reason_value = processed_response.get("stop_reason", "end_turn") + stop_reason = stop_reason_value if isinstance(stop_reason_value, str) else "end_turn" # Determine if we should parse for system-prompt tool calls (unified capabilities) caps_tmp = self.capabilities.get(model) or ModelCapabilities() @@ -1830,7 +1835,7 @@ async def _bedrock_completion( if stop_reason == "end_turn" and tools: # Only parse for tools if text contains actual function call structure message_text = "" - for content_item in processed_response.get("content", []): + for content_item in content_items or []: if isinstance(content_item, dict) and "text" in content_item: message_text += content_item.get("text", "") @@ -1955,8 +1960,9 @@ def get_field_type_representation(field_type: Any) -> Any: def _generate_schema_dict(model_class: Type) -> dict[str, Any]: """Recursively generate the schema as a dictionary.""" schema_dict = {} - if hasattr(model_class, "model_fields"): - for field_name, field_info in model_class.model_fields.items(): + model_fields = getattr(model_class, "model_fields", None) + if isinstance(model_fields, dict): + for field_name, field_info in model_fields.items(): schema_dict[field_name] = get_field_type_representation(field_info.annotation) return schema_dict From d46a6f04f681990f505738157bfa224f5b2db62c Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Tue, 23 Dec 2025 14:07:30 +0000 Subject: [PATCH 05/15] type safety --- examples/tool-use-agent/agent.py | 4 +- .../src/hf_inference_acp/agents.py | 30 +- src/fast_agent/acp/protocols.py | 3 + src/fast_agent/acp/server/agent_acp_server.py | 25 +- src/fast_agent/acp/slash_commands.py | 2 +- src/fast_agent/agents/mcp_agent.py | 29 +- src/fast_agent/agents/tool_agent.py | 14 +- .../agents/workflow/router_agent.py | 14 +- src/fast_agent/cli/commands/go.py | 59 ++- src/fast_agent/core/fastagent.py | 8 +- src/fast_agent/core/prompt_templates.py | 10 +- src/fast_agent/interfaces.py | 2 + src/fast_agent/llm/hf_inference_lookup.py | 122 +++++- .../multipart_converter_anthropic.py | 2 +- src/fast_agent/mcp/helpers/content_helpers.py | 12 +- src/fast_agent/mcp/mcp_aggregator.py | 168 +++++--- src/fast_agent/mcp/mcp_connection_manager.py | 7 +- src/fast_agent/mcp/mcp_content.py | 74 +++- src/fast_agent/mcp/prompt_serialization.py | 6 +- 
src/fast_agent/mcp/prompts/prompt_server.py | 33 +- src/fast_agent/mcp/prompts/prompt_template.py | 17 +- src/fast_agent/mcp/resource_utils.py | 47 ++- src/fast_agent/mcp/server/agent_server.py | 149 ++++--- src/fast_agent/types/__init__.py | 6 + src/fast_agent/types/tool_timing.py | 11 + src/fast_agent/ui/command_payloads.py | 6 +- src/fast_agent/ui/enhanced_prompt.py | 8 +- src/fast_agent/ui/interactive_prompt.py | 372 +++++++++--------- src/fast_agent/ui/mcp_display.py | 5 +- src/fast_agent/ui/mcp_ui_utils.py | 9 +- src/fast_agent/ui/tool_display.py | 4 +- .../acp/test_acp_skills_manager.py | 63 ++- .../acp/test_acp_slash_commands.py | 3 + .../acp/test_set_model_validation.py | 221 +++++++++++ tests/integration/api/test_prompt_commands.py | 34 +- .../llm/test_hf_inference_lookup_unit.py | 183 +++++++++ uv.lock | 22 +- 37 files changed, 1241 insertions(+), 543 deletions(-) create mode 100644 src/fast_agent/types/tool_timing.py create mode 100644 tests/integration/acp/test_set_model_validation.py create mode 100644 tests/unit/fast_agent/llm/test_hf_inference_lookup_unit.py diff --git a/examples/tool-use-agent/agent.py b/examples/tool-use-agent/agent.py index 95b19b672..82420e309 100644 --- a/examples/tool-use-agent/agent.py +++ b/examples/tool-use-agent/agent.py @@ -26,7 +26,9 @@ def __init__( @fast.custom(CustomToolAgent) async def main() -> None: async with fast.run() as agent: - await agent.default.generate("What is the topic of the video call no.1234?") + await agent.default.generate( + "What is the topic of the video call no.1234?", + ) if __name__ == "__main__": diff --git a/publish/hf-inference-acp/src/hf_inference_acp/agents.py b/publish/hf-inference-acp/src/hf_inference_acp/agents.py index 547aac01d..0ba57d750 100644 --- a/publish/hf-inference-acp/src/hf_inference_acp/agents.py +++ b/publish/hf-inference-acp/src/hf_inference_acp/agents.py @@ -220,6 +220,7 @@ def acp_session_commands_allowlist(self) -> set[str]: async def _handle_set_model(self, arguments: str) -> str: """Handler for /set-model command.""" + from fast_agent.llm.hf_inference_lookup import validate_hf_model from fast_agent.llm.model_factory import ModelFactory model = arguments.strip() @@ -229,20 +230,24 @@ async def _handle_set_model(self, arguments: str) -> str: # Normalize the model string (auto-add hf. prefix if needed) model = _normalize_hf_model(model) - # Validate the model string before saving to config + # Validate the model string format try: ModelFactory.parse_model_string(model) except Exception as e: return f"Error: Invalid model `{model}` - {e}" - # Look up inference providers for this model - provider_info = await _lookup_and_format_providers(model) + # Validate model exists on HuggingFace and has providers + validation = await validate_hf_model(model, aliases=ModelFactory.MODEL_ALIASES) + if not validation.valid: + return validation.error or "Error: Model validation failed" try: update_model_in_config(model) applied = await self._apply_model_to_running_hf_agent(model) applied_note = "\n\nApplied to the running Hugging Face agent." 
if applied else "" - provider_prefix = f"{provider_info}\n\n" if provider_info else "" + provider_prefix = ( + f"{validation.display_message}\n\n" if validation.display_message else "" + ) return ( f"{provider_prefix}" f"Default model set to: `{model}`\n\nConfig file updated: `{CONFIG_FILE}`" @@ -498,9 +503,7 @@ async def _send_connect_update( await _send_connect_update(title="Connected", status="in_progress") # Rebuild system prompt to include fresh server instructions - await _send_connect_update( - title="Rebuilding system prompt…", status="in_progress" - ) + await _send_connect_update(title="Rebuilding system prompt…", status="in_progress") await self.rebuild_instruction_templates() # Get available tools @@ -543,6 +546,7 @@ async def _send_connect_update( async def _handle_set_model(self, arguments: str) -> str: """Handler for /set-model in Hugging Face mode.""" + from fast_agent.llm.hf_inference_lookup import validate_hf_model from fast_agent.llm.model_factory import ModelFactory model = arguments.strip() @@ -552,20 +556,24 @@ async def _handle_set_model(self, arguments: str) -> str: # Normalize the model string (auto-add hf. prefix if needed) model = _normalize_hf_model(model) - # Validate the model string before applying + # Validate the model string format try: ModelFactory.parse_model_string(model) except Exception as e: return f"Error: Invalid model `{model}` - {e}" - # Look up inference providers for this model - provider_info = await _lookup_and_format_providers(model) + # Validate model exists on HuggingFace and has providers + validation = await validate_hf_model(model, aliases=ModelFactory.MODEL_ALIASES) + if not validation.valid: + return validation.error or "Error: Model validation failed" try: # Apply model first - if this fails, don't update config await self.apply_model(model) update_model_in_config(model) - provider_prefix = f"{provider_info}\n\n" if provider_info else "" + provider_prefix = ( + f"{validation.display_message}\n\n" if validation.display_message else "" + ) return f"{provider_prefix}Active model set to: `{model}`\n\nConfig file updated: `{CONFIG_FILE}`" except Exception as e: return f"Error setting model: {e}" diff --git a/src/fast_agent/acp/protocols.py b/src/fast_agent/acp/protocols.py index 06651de04..4398a36ac 100644 --- a/src/fast_agent/acp/protocols.py +++ b/src/fast_agent/acp/protocols.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from fast_agent.acp.filesystem_runtime import ACPFilesystemRuntime from fast_agent.acp.terminal_runtime import ACPTerminalRuntime + from fast_agent.tools.shell_runtime import ShellRuntime from fast_agent.workflow_telemetry import PlanTelemetryProvider, WorkflowTelemetryProvider @@ -19,6 +20,8 @@ class ShellRuntimeCapable(Protocol): """Agent that supports external shell runtime injection.""" + _shell_runtime: "ShellRuntime" + @property def _shell_runtime_enabled(self) -> bool: ... 
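A note on the runtime_checkable Protocols this patch leans on: isinstance() against such a Protocol verifies only that the named members exist, never their signatures or value types, so the annotations above still carry the real contract. A minimal, self-contained sketch of the pattern with illustrative stand-in classes (these are not the fast_agent definitions):

from typing import Protocol, runtime_checkable


@runtime_checkable
class TerminalCapable(Protocol):  # stand-in for the capability Protocols above
    runtime_enabled: bool

    def set_external_runtime(self, runtime: object) -> None: ...


class GoodAgent:
    runtime_enabled = True

    def set_external_runtime(self, runtime: object) -> None:
        self._runtime = runtime


class BadAgent:
    runtime_enabled = True  # has the attribute but lacks the method


assert isinstance(GoodAgent(), TerminalCapable)
assert not isinstance(BadAgent(), TerminalCapable)


# Caveat: the runtime check is presence-only. This class still passes
# isinstance() although the attribute type and the method arity are wrong;
# only a static type checker catches that.
class SloppyAgent:
    runtime_enabled = "yes"

    def set_external_runtime(self) -> None:
        pass


assert isinstance(SloppyAgent(), TerminalCapable)

The static side is what makes the swap from hasattr() worthwhile: after an isinstance() check, the type checker narrows the variable and verifies every subsequent attribute access against the Protocol.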
diff --git a/src/fast_agent/acp/server/agent_acp_server.py b/src/fast_agent/acp/server/agent_acp_server.py index 6012b97f1..325254608 100644 --- a/src/fast_agent/acp/server/agent_acp_server.py +++ b/src/fast_agent/acp/server/agent_acp_server.py @@ -68,13 +68,14 @@ TERMINAL_OUTPUT_TOKEN_HEADROOM_RATIO, TERMINAL_OUTPUT_TOKEN_RATIO, ) +from fast_agent.context import Context from fast_agent.core.fastagent import AgentInstance from fast_agent.core.logging.logger import get_logger from fast_agent.core.prompt_templates import ( apply_template_variables, enrich_with_environment_context, ) -from fast_agent.interfaces import ACPAwareProtocol, StreamingAgentProtocol +from fast_agent.interfaces import ACPAwareProtocol, AgentProtocol, StreamingAgentProtocol from fast_agent.llm.model_database import ModelDatabase from fast_agent.llm.stream_types import StreamChunk from fast_agent.llm.usage_tracking import last_turn_usage @@ -826,27 +827,15 @@ async def new_session( # Set ACPContext on each agent's Context object (if they have one) for agent_name, agent in instance.agents.items(): - if hasattr(agent, "_context") and agent._context is not None: - agent._context.acp = acp_context + context = getattr(agent, "context", None) + if isinstance(context, Context): + context.acp = acp_context logger.debug( "ACPContext set on agent", name="acp_context_set", session_id=session_id, agent_name=agent_name, ) - elif hasattr(agent, "context"): - # Try via context property - try: - agent.context.acp = acp_context - logger.debug( - "ACPContext set on agent via context property", - name="acp_context_set", - session_id=session_id, - agent_name=agent_name, - ) - except Exception: - # Agent may not have a context available - pass logger.info( "ACPContext created for session", @@ -1186,7 +1175,7 @@ def on_stream_chunk(chunk: StreamChunk): agent, session_state ) turn_start_index = None - if getattr(agent, "usage_accumulator", None) is not None: + if isinstance(agent, AgentProtocol) and agent.usage_accumulator is not None: turn_start_index = len(agent.usage_accumulator.turns) result = await agent.generate( prompt_message, @@ -1328,7 +1317,7 @@ def on_stream_chunk(chunk: StreamChunk): # Return response with appropriate stop reason return PromptResponse( stop_reason=acp_stop_reason, - _meta=status_line_meta, + field_meta=status_line_meta, ) except asyncio.CancelledError: # Task was cancelled - return appropriate response diff --git a/src/fast_agent/acp/slash_commands.py b/src/fast_agent/acp/slash_commands.py index 0c0b368ce..9bf1250d4 100644 --- a/src/fast_agent/acp/slash_commands.py +++ b/src/fast_agent/acp/slash_commands.py @@ -980,7 +980,7 @@ async def _refresh_agent_skills(self, agent: AgentProtocol) -> None: await rebuild_agent_instruction( agent, skill_manifests=manifests, - instruction_context=instruction_context, + context=instruction_context, skill_registry=registry, ) diff --git a/src/fast_agent/agents/mcp_agent.py b/src/fast_agent/agents/mcp_agent.py index 3067483e2..f61bf5265 100644 --- a/src/fast_agent/agents/mcp_agent.py +++ b/src/fast_agent/agents/mcp_agent.py @@ -55,7 +55,12 @@ ) from fast_agent.tools.shell_runtime import ShellRuntime from fast_agent.tools.skill_reader import SkillReader -from fast_agent.types import PromptMessageExtended, RequestParams +from fast_agent.types import ( + PromptMessageExtended, + RequestParams, + ToolTimingInfo, + ToolTimings, +) from fast_agent.ui import console # Define a TypeVar for models @@ -570,7 +575,7 @@ async def _call_human_input_tool( """ try: # Run via shared tool runner - 
resp_text = await run_elicitation_form(arguments, agent_name=self._name) + resp_text = await run_elicitation_form(arguments or {}, agent_name=self._name) if resp_text == "__DECLINED__": return CallToolResult( isError=False, @@ -806,7 +811,7 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend return PromptMessageExtended(role="user", tool_results={}) tool_results: dict[str, CallToolResult] = {} - tool_timings: dict[str, float] = {} # Track timing for each tool call + tool_timings: ToolTimings = {} # Track timing for each tool call tool_loop_error: str | None = None # Cache available tool names exactly as advertised to the LLM for display/highlighting @@ -848,11 +853,11 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend ) # Select display/highlight names - display_tool_name = ( - (namespaced_tool or candidate_namespaced_tool).namespaced_tool_name - if (namespaced_tool or candidate_namespaced_tool) is not None - else tool_name - ) + active_namespaced = namespaced_tool or candidate_namespaced_tool + if active_namespaced is not None: + display_tool_name = active_namespaced.namespaced_tool_name + else: + display_tool_name = tool_name # Check if tool is available from various sources is_external_runtime_tool = ( @@ -930,10 +935,10 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend tool_results[correlation_id] = result # Store timing and transport channel info - tool_timings[correlation_id] = { - "timing_ms": duration_ms, - "transport_channel": getattr(result, "transport_channel", None), - } + tool_timings[correlation_id] = ToolTimingInfo( + timing_ms=duration_ms, + transport_channel=getattr(result, "transport_channel", None), + ) # Show tool result (like ToolAgent does) skybridge_config = None diff --git a/src/fast_agent/agents/tool_agent.py b/src/fast_agent/agents/tool_agent.py index 1878b155a..e3c588fb8 100644 --- a/src/fast_agent/agents/tool_agent.py +++ b/src/fast_agent/agents/tool_agent.py @@ -14,7 +14,7 @@ from fast_agent.core.logging.logger import get_logger from fast_agent.mcp.helpers.content_helpers import text_content from fast_agent.tools.elicitation import get_elicitation_fastmcp_tool -from fast_agent.types import PromptMessageExtended, RequestParams +from fast_agent.types import PromptMessageExtended, RequestParams, ToolTimingInfo, ToolTimings from fast_agent.types.llm_stop_reason import LlmStopReason logger = get_logger(__name__) @@ -140,7 +140,7 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend return PromptMessageExtended(role="user", tool_results={}) tool_results: dict[str, CallToolResult] = {} - tool_timings: dict[str, float] = {} # Track timing for each tool call + tool_timings: ToolTimings = {} # Track timing for each tool call tool_loop_error: str | None = None # TODO -- use gather() for parallel results, update display tool_schemas = (await self.list_tools()).tools @@ -184,10 +184,10 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend tool_results[correlation_id] = result # Store timing info (transport_channel not available for local tools) - tool_timings[correlation_id] = { - "timing_ms": duration_ms, - "transport_channel": None - } + tool_timings[correlation_id] = ToolTimingInfo( + timing_ms=duration_ms, + transport_channel=None, + ) self.display.show_tool_result(name=self.name, result=result, tool_name=tool_name, timing_ms=duration_ms) return self._finalize_tool_results(tool_results, tool_timings=tool_timings, 
tool_loop_error=tool_loop_error) @@ -211,7 +211,7 @@ def _finalize_tool_results( self, tool_results: dict[str, CallToolResult], *, - tool_timings: dict[str, dict[str, float | str | None]] | None = None, + tool_timings: ToolTimings | None = None, tool_loop_error: str | None = None, ) -> PromptMessageExtended: import json diff --git a/src/fast_agent/agents/workflow/router_agent.py b/src/fast_agent/agents/workflow/router_agent.py index ca4808c33..1cae2f871 100644 --- a/src/fast_agent/agents/workflow/router_agent.py +++ b/src/fast_agent/agents/workflow/router_agent.py @@ -100,12 +100,18 @@ def __init__( self.agent_map = {agent.name: agent for agent in agents} # Set up base router request parameters with just the base instruction for now - base_params = {"systemPrompt": ROUTING_SYSTEM_INSTRUCTION, "use_history": False} - if default_request_params: - merged_params = default_request_params.model_copy(update=base_params) + merged_params = default_request_params.model_copy( + update={ + "systemPrompt": ROUTING_SYSTEM_INSTRUCTION, + "use_history": False, + } + ) else: - merged_params = RequestParams(**base_params) + merged_params = RequestParams( + systemPrompt=ROUTING_SYSTEM_INSTRUCTION, + use_history=False, + ) self._default_request_params = merged_params diff --git a/src/fast_agent/cli/commands/go.py b/src/fast_agent/cli/commands/go.py index 51e7af79f..39c1c6a18 100644 --- a/src/fast_agent/cli/commands/go.py +++ b/src/fast_agent/cli/commands/go.py @@ -5,7 +5,7 @@ import shlex import sys from pathlib import Path -from typing import Literal +from typing import Any, Literal, cast import typer @@ -112,8 +112,8 @@ async def _run_agent( model: str | None = None, message: str | None = None, prompt_file: str | None = None, - url_servers: dict[str, dict[str, str]] | None = None, - stdio_servers: dict[str, dict[str, str]] | None = None, + url_servers: dict[str, dict[str, Any]] | None = None, + stdio_servers: dict[str, dict[str, Any]] | None = None, agent_name: str | None = "agent", skills_directory: Path | None = None, shell_runtime: bool = False, @@ -130,18 +130,14 @@ async def _run_agent( # Create the FastAgent instance - fast_kwargs = { - "name": name, - "config_path": config_path, - "ignore_unknown_args": True, - "parse_cli_args": False, # Don't parse CLI args, we're handling it ourselves - } - if mode == "serve": - fast_kwargs["quiet"] = True - if skills_directory is not None: - fast_kwargs["skills_directory"] = skills_directory - - fast = FastAgent(**fast_kwargs) + fast = FastAgent( + name=name, + config_path=config_path, + ignore_unknown_args=True, + parse_cli_args=False, # Don't parse CLI args, we're handling it ourselves + quiet=mode == "serve", + skills_directory=skills_directory, + ) # Set model on args so model source detection works correctly if model: @@ -152,8 +148,10 @@ async def _run_agent( setattr(fast.app.context, "shell_runtime", True) # Add all dynamic servers to the configuration - await add_servers_to_config(fast, url_servers) - await add_servers_to_config(fast, stdio_servers) + if url_servers: + await add_servers_to_config(fast, cast("dict[str, dict[str, Any]]", url_servers)) + if stdio_servers: + await add_servers_to_config(fast, cast("dict[str, dict[str, Any]]", stdio_servers)) # Check if we have multiple models (comma-delimited) if model and "," in model: @@ -166,12 +164,12 @@ async def _run_agent( agent_name = f"{model_name}" # Define the agent with specified parameters - agent_kwargs = {"instruction": instruction, "name": agent_name} - if server_list: - agent_kwargs["servers"] 
= server_list - agent_kwargs["model"] = model_name - - @fast.agent(**agent_kwargs) + @fast.agent( + name=agent_name, + instruction=instruction, + servers=server_list or [], + model=model_name, + ) async def model_agent(): pass @@ -218,15 +216,12 @@ async def cli_agent(): else: # Single model - use original behavior # Define the agent with specified parameters - agent_kwargs = {"instruction": instruction} - if agent_name: - agent_kwargs["name"] = agent_name - if server_list: - agent_kwargs["servers"] = server_list - if model: - agent_kwargs["model"] = model - - @fast.agent(**agent_kwargs) + @fast.agent( + name=agent_name or "agent", + instruction=instruction, + servers=server_list or [], + model=model, + ) async def cli_agent(): async with fast.run() as agent: if message: diff --git a/src/fast_agent/core/fastagent.py b/src/fast_agent/core/fastagent.py index 5bd919855..59e8900f9 100644 --- a/src/fast_agent/core/fastagent.py +++ b/src/fast_agent/core/fastagent.py @@ -601,7 +601,13 @@ async def run(self) -> AsyncIterator["AgentApp"]: default_skills: list[SkillManifest] = [] if registry: - default_skills = registry.load_manifests() + try: + default_skills = registry.load_manifests() + except Exception as exc: # noqa: BLE001 + logger.warning( + "Failed to load skills; continuing without them", + data={"error": str(exc)}, + ) self._apply_skills_to_agent_configs(default_skills) diff --git a/src/fast_agent/core/prompt_templates.py b/src/fast_agent/core/prompt_templates.py index 528e0b7d8..0feb9a003 100644 --- a/src/fast_agent/core/prompt_templates.py +++ b/src/fast_agent/core/prompt_templates.py @@ -9,9 +9,13 @@ from pathlib import Path from typing import TYPE_CHECKING, Mapping, MutableMapping, Sequence +from fast_agent.core.logging.logger import get_logger + if TYPE_CHECKING: from fast_agent.skills import SkillManifest +logger = get_logger(__name__) + def apply_template_variables( template: str | None, variables: Mapping[str, str | None] | None @@ -134,7 +138,11 @@ def load_skills_for_context( override_dirs.append(base_dir / override_path) registry = SkillRegistry(base_dir=base_dir, directories=override_dirs) - return registry.load_manifests() + try: + return registry.load_manifests() + except Exception as exc: # noqa: BLE001 + logger.warning("Failed to load skills; continuing without them", data={"error": str(exc)}) + return [] def enrich_with_environment_context( diff --git a/src/fast_agent/interfaces.py b/src/fast_agent/interfaces.py index ea3b4ea7a..621fdd376 100644 --- a/src/fast_agent/interfaces.py +++ b/src/fast_agent/interfaces.py @@ -96,6 +96,8 @@ def add_tool_stream_listener( self, listener: Callable[[str, dict[str, Any] | None], None] ) -> Callable[[], None]: ... + def chat_turn(self) -> int: ... + @property def message_history(self) -> list[PromptMessageExtended]: ... 
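Both load_manifests() guards introduced above share one degrade-gracefully shape: a broken skills directory is logged and startup continues with no skills rather than aborting. A generic sketch of that boundary pattern, under invented names:

import logging
from typing import Callable, TypeVar

T = TypeVar("T")
logger = logging.getLogger(__name__)


def load_or_default(loader: Callable[[], list[T]], what: str) -> list[T]:
    # Mirrors the guards above: log once, fall back to an empty list, and
    # let the caller proceed without the optional feature.
    try:
        return loader()
    except Exception as exc:  # noqa: BLE001 - deliberate catch-all at the boundary
        logger.warning("Failed to load %s; continuing without them: %s", what, exc)
        return []


def broken_loader() -> list[str]:
    raise RuntimeError("unreadable skills directory")


assert load_or_default(lambda: ["skill-a"], "skills") == ["skill-a"]
assert load_or_default(broken_loader, "skills") == []

The catch-all is intentional: at a startup boundary the specific failure mode (bad YAML, missing directory, permission error) does not change the recovery, which is always "no skills".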
diff --git a/src/fast_agent/llm/hf_inference_lookup.py b/src/fast_agent/llm/hf_inference_lookup.py index ab4dfd0a6..e19ea690a 100644 --- a/src/fast_agent/llm/hf_inference_lookup.py +++ b/src/fast_agent/llm/hf_inference_lookup.py @@ -7,54 +7,53 @@ from __future__ import annotations import asyncio -from dataclasses import dataclass from enum import Enum from typing import TYPE_CHECKING import httpx +from pydantic import BaseModel, Field, computed_field if TYPE_CHECKING: + from collections.abc import Awaitable, Callable from typing import Any + # Type alias for lookup function - allows dependency injection for testing + InferenceLookupFn = Callable[[str], Awaitable["InferenceProviderLookupResult"]] -class InferenceProviderStatus(Enum): + +class InferenceProviderStatus(str, Enum): """Status of an inference provider for a model.""" LIVE = "live" STAGING = "staging" -@dataclass -class InferenceProvider: +class InferenceProvider(BaseModel): """Information about an inference provider for a model.""" name: str - status: InferenceProviderStatus - provider_id: str - task: str - is_model_author: bool + status: InferenceProviderStatus = InferenceProviderStatus.LIVE + provider_id: str = Field(default="", alias="providerId") + task: str = "" + is_model_author: bool = Field(default=False, alias="isModelAuthor") + + model_config = {"populate_by_name": True} @classmethod def from_dict(cls, name: str, data: dict[str, Any]) -> "InferenceProvider": """Create an InferenceProvider from API response data.""" - return cls( - name=name, - status=InferenceProviderStatus(data.get("status", "live")), - provider_id=data.get("providerId", ""), - task=data.get("task", ""), - is_model_author=data.get("isModelAuthor", False), - ) + return cls(name=name, **data) -@dataclass -class InferenceProviderLookupResult: +class InferenceProviderLookupResult(BaseModel): """Result of looking up inference providers for a model.""" model_id: str exists: bool - providers: list[InferenceProvider] + providers: list[InferenceProvider] = Field(default_factory=list) error: str | None = None + @computed_field # type: ignore[prop-decorator] @property def has_providers(self) -> bool: """Return True if the model has any live inference providers.""" @@ -77,18 +76,29 @@ def format_model_strings(self) -> list[str]: return [f"{self.model_id}:{p.name}" for p in self.live_providers] +class ModelValidationResult(BaseModel): + """Result of validating an HF model for /set-model.""" + + valid: bool + display_message: str = "" + error: str | None = None + + HF_API_BASE = "https://huggingface.co/api/models" async def lookup_inference_providers( model_id: str, timeout: float = 10.0, + *, + lookup_fn: InferenceLookupFn | None = None, ) -> InferenceProviderLookupResult: """Look up available inference providers for a HuggingFace model. Args: model_id: The HuggingFace model ID (e.g., "moonshotai/Kimi-K2-Thinking") timeout: Request timeout in seconds + lookup_fn: Optional function to use for lookup (for testing) Returns: InferenceProviderLookupResult with provider information @@ -100,6 +110,10 @@ async def lookup_inference_providers( ... for model_str in result.format_model_strings(): ... print(f" hf.{model_str}") """ + # Allow test injection + if lookup_fn is not None: + return await lookup_fn(model_id) + # Normalize model_id - strip any hf. 
prefix if model_id.startswith("hf."): model_id = model_id[3:] @@ -223,3 +237,75 @@ def format_inference_lookup_message(result: InferenceProviderLookupResult) -> st lines.append(f"- `hf.{model_str}`") return "\n".join(lines) + + +async def validate_hf_model( + model: str, + *, + aliases: dict[str, str] | None = None, + lookup_fn: InferenceLookupFn | None = None, +) -> ModelValidationResult: + """Validate that an HF model exists and has inference providers. + + Args: + model: The model string (e.g., "hf.moonshotai/Kimi-K2-Thinking:together") + Can also be an alias like "kimi" or "glm" that resolves to an HF model. + aliases: Optional dict of model aliases (e.g., {"kimi": "hf.moonshotai/..."}). + If not provided, no alias resolution is performed. + lookup_fn: Optional function to use for lookup (for testing) + + Returns: + ModelValidationResult with validation status and messages + """ + import random + + # Resolve aliases first (e.g., "kimi" -> "hf.moonshotai/Kimi-K2-Instruct-0905:groq") + if aliases: + model = aliases.get(model, model) + + # Extract the HF model ID from various formats + model_id = model + + # Strip hf. prefix if present + if model_id.startswith("hf."): + model_id = model_id[3:] + + # Strip :provider suffix if present + if ":" in model_id: + model_id = model_id.rsplit(":", 1)[0] + + # Must have org/model format to be an HF model + if "/" not in model_id: + # Not an HF model - skip validation (let ModelFactory handle it) + return ModelValidationResult(valid=True) + + try: + result = await lookup_inference_providers(model_id, lookup_fn=lookup_fn) + + if not result.exists: + return ModelValidationResult( + valid=False, + error=f"Error: Model `{model_id}` not found on HuggingFace", + ) + + if not result.has_providers: + return ModelValidationResult( + valid=False, + error=f"Error: Model `{model_id}` exists but has no inference providers available", + ) + + # Valid model with providers + providers = result.format_provider_list() + model_strings = result.format_model_strings() + example = random.choice(model_strings) + display_message = ( + f"**Available providers:** {providers}\n\n" + f"**Autoroutes if no provider specified. Example use:** `/set-model {example}`" + ) + return ModelValidationResult(valid=True, display_message=display_message) + + except Exception as e: + return ModelValidationResult( + valid=False, + error=f"Error: Failed to validate model `{model_id}`: {e}", + ) diff --git a/src/fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py b/src/fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py index 2d455ecc3..82626f70b 100644 --- a/src/fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py +++ b/src/fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py @@ -369,7 +369,7 @@ def _convert_svg_resource(resource_content) -> TextBlockParam: @staticmethod def _create_fallback_text( - message: str, resource: Union[TextContent, ImageContent, EmbeddedResource] + message: str, resource: ContentBlock ) -> TextBlockParam: """ Create a fallback text block for unsupported resource types. 
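The lookup_fn parameter threaded through lookup_inference_providers() and validate_hf_model() is a dependency-injection seam for tests: a fake lookup avoids the HTTP round-trip to huggingface.co. A plausible usage sketch; the model string and provider data are invented for illustration:

import asyncio

from fast_agent.llm.hf_inference_lookup import (
    InferenceProvider,
    InferenceProviderLookupResult,
    validate_hf_model,
)


async def fake_lookup(model_id: str) -> InferenceProviderLookupResult:
    # Stand-in for the real call to https://huggingface.co/api/models.
    return InferenceProviderLookupResult(
        model_id=model_id,
        exists=True,
        providers=[InferenceProvider(name="together")],
    )


async def main() -> None:
    result = await validate_hf_model(
        "hf.example-org/example-model:together",  # hypothetical model string
        lookup_fn=fake_lookup,
    )
    assert result.valid
    print(result.display_message)


asyncio.run(main())

Because the status Enum now derives from str and the dataclasses became pydantic models, fakes like this serialize and compare cleanly in test assertions.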
diff --git a/src/fast_agent/mcp/helpers/content_helpers.py b/src/fast_agent/mcp/helpers/content_helpers.py index da1a41f7d..1f5b25a9b 100644 --- a/src/fast_agent/mcp/helpers/content_helpers.py +++ b/src/fast_agent/mcp/helpers/content_helpers.py @@ -3,7 +3,7 @@ """ -from typing import TYPE_CHECKING, Sequence, Union +from typing import TYPE_CHECKING, Sequence, TypeGuard, Union if TYPE_CHECKING: from fast_agent.mcp.prompt_message_extended import PromptMessageExtended @@ -70,22 +70,22 @@ def get_resource_uri(content: ContentBlock) -> str | None: return None -def is_text_content(content: ContentBlock) -> bool: +def is_text_content(content: ContentBlock) -> TypeGuard[TextContent | TextResourceContents]: """Check if the content is text content.""" - return isinstance(content, TextContent) or isinstance(content, TextResourceContents) + return isinstance(content, (TextContent, TextResourceContents)) -def is_image_content(content: Union[TextContent, ImageContent, EmbeddedResource]) -> bool: +def is_image_content(content: ContentBlock) -> TypeGuard[ImageContent]: """Check if the content is image content.""" return isinstance(content, ImageContent) -def is_resource_content(content: ContentBlock) -> bool: +def is_resource_content(content: ContentBlock) -> TypeGuard[EmbeddedResource]: """Check if the content is an embedded resource.""" return isinstance(content, EmbeddedResource) -def is_resource_link(content: ContentBlock) -> bool: +def is_resource_link(content: ContentBlock) -> TypeGuard[ResourceLink]: """Check if the content is a resource link.""" return isinstance(content, ResourceLink) diff --git a/src/fast_agent/mcp/mcp_aggregator.py b/src/fast_agent/mcp/mcp_aggregator.py index 890be6595..0e1af24aa 100644 --- a/src/fast_agent/mcp/mcp_aggregator.py +++ b/src/fast_agent/mcp/mcp_aggregator.py @@ -32,6 +32,7 @@ from fast_agent.event_progress import ProgressAction from fast_agent.mcp.common import SEP, create_namespaced_name, is_namespaced_name from fast_agent.mcp.gen_client import gen_client +from fast_agent.mcp.interfaces import ServerRegistryProtocol from fast_agent.mcp.mcp_agent_client_session import MCPAgentClientSession from fast_agent.mcp.mcp_connection_manager import MCPConnectionManager from fast_agent.mcp.skybridge import ( @@ -46,6 +47,7 @@ if TYPE_CHECKING: from fast_agent.context import Context + from fast_agent.mcp_server_registry import ServerRegistry logger = get_logger(__name__) # This will be replaced per-instance when agent_name is available @@ -138,19 +140,15 @@ async def __aenter__(self): # Keep a connection manager to manage persistent connections for this aggregator if self.connection_persistence: + context = self._require_context() # Try to get existing connection manager from context - context = self.context if not hasattr(context, "_connection_manager") or context._connection_manager is None: - server_registry = context.server_registry - if server_registry is None: - raise RuntimeError("Context is missing server registry for MCP connections") + server_registry = cast("ServerRegistry", self._require_server_registry()) manager = MCPConnectionManager(server_registry, context=context) await manager.__aenter__() context._connection_manager = manager self._owns_connection_manager = True - self._persistent_connection_manager = cast( - "MCPConnectionManager", context._connection_manager - ) + self._persistent_connection_manager = context._connection_manager else: self._persistent_connection_manager = None @@ -231,6 +229,23 @@ def __init__( # Track discovered Skybridge configurations 
per server self._skybridge_configs: dict[str, SkybridgeServerConfig] = {} + def _require_context(self) -> "Context": + if self.context is None: + raise RuntimeError("MCPAggregator requires a context") + return self.context + + def _require_server_registry(self) -> ServerRegistryProtocol: + context = self._require_context() + server_registry = getattr(context, "server_registry", None) + if server_registry is None: + raise RuntimeError("Context is missing server registry for MCP connections") + return server_registry + + def _require_connection_manager(self) -> MCPConnectionManager: + if self._persistent_connection_manager is None: + raise RuntimeError("Persistent connection manager is not initialized") + return self._persistent_connection_manager + def _create_progress_callback( self, server_name: str, tool_name: str, tool_call_id: str ) -> "ProgressFnT": @@ -308,6 +323,7 @@ async def create( except Exception as e: logger.error(f"Error creating MCPAggregator: {e}") await instance.__aexit__(None, None, None) + raise def _create_session_factory(self, server_name: str): """ @@ -375,8 +391,9 @@ async def load_servers(self, *, force_connect: bool = False) -> None: for server_name in self.server_names: # Check if server should be loaded on start - if self.context and getattr(self.context, "server_registry", None): - server_config = self.context.server_registry.get_server_config(server_name) + server_registry = self.context.server_registry if self.context else None + if server_registry is not None: + server_config = server_registry.get_server_config(server_name) if ( server_config and not getattr(server_config, "load_on_start", True) @@ -398,7 +415,8 @@ async def load_servers(self, *, force_connect: bool = False) -> None: }, ) - await self._persistent_connection_manager.get_server( + manager = self._require_connection_manager() + await manager.get_server( server_name, client_session_factory=self._create_session_factory(server_name) ) @@ -717,7 +735,8 @@ async def get_capabilities(self, server_name: str): return None try: - server_conn = await self._persistent_connection_manager.get_server( + manager = self._require_connection_manager() + server_conn = await manager.get_server( server_name, client_session_factory=self._create_session_factory(server_name) ) # server_capabilities is a property, not a coroutine @@ -846,12 +865,12 @@ async def _notify_stdio_transport_activity( if not server_conn: return - server_config = getattr(server_conn, "server_config", None) - if not server_config or server_config.transport != "stdio": + server_config = server_conn.server_config + if server_config.transport != "stdio": return # Get transport metrics and emit synthetic message event - transport_metrics = getattr(server_conn, "transport_metrics", None) + transport_metrics = server_conn.transport_metrics if transport_metrics: # Import here to avoid circular imports from fast_agent.mcp.transport_tracking import ChannelEvent @@ -899,7 +918,7 @@ async def get_server_instructions(self) -> dict[str, tuple[str | None, list[str] continue try: - if hasattr(server_conn, "is_healthy") and not server_conn.is_healthy(): + if not server_conn.is_healthy(): continue except Exception: continue @@ -911,7 +930,7 @@ async def get_server_instructions(self) -> dict[str, tuple[str | None, list[str] ] try: - instructions[server_name] = (getattr(server_conn, "server_instructions", None), tool_names) + instructions[server_name] = (server_conn.server_instructions, tool_names) except Exception as e: logger.debug(f"Failed to get instructions 
from server {server_name}: {e}") @@ -962,36 +981,32 @@ async def collect_server_status(self) -> dict[str, ServerStatus]: server_name, client_session_factory=self._create_session_factory(server_name), ) - implementation = getattr(server_conn, "server_implementation", None) - if implementation: - implementation_name = getattr(implementation, "name", None) - implementation_version = getattr(implementation, "version", None) - capabilities = getattr(server_conn, "server_capabilities", None) - client_capabilities = getattr(server_conn, "client_capabilities", None) + implementation = server_conn.server_implementation + if implementation is not None: + implementation_name = implementation.name + implementation_version = implementation.version + capabilities = server_conn.server_capabilities + client_capabilities = server_conn.client_capabilities session = server_conn.session client_info = getattr(session, "client_info", None) if session else None if client_info: client_info_name = getattr(client_info, "name", None) client_info_version = getattr(client_info, "version", None) is_connected = server_conn.is_healthy() - error_message = getattr(server_conn, "_error_message", None) - instructions_available = getattr( - server_conn, "server_instructions_available", None - ) - instructions_enabled = getattr(server_conn, "server_instructions_enabled", None) - instructions_included = bool(getattr(server_conn, "server_instructions", None)) - server_cfg = getattr(server_conn, "server_config", None) + error_message = server_conn._error_message + instructions_available = server_conn.server_instructions_available + instructions_enabled = server_conn.server_instructions_enabled + instructions_included = bool(server_conn.server_instructions) + server_cfg = server_conn.server_config if session: - elicitation_mode = getattr( - session, "effective_elicitation_mode", elicitation_mode - ) - session_id = getattr(server_conn, "session_id", None) - if not session_id and getattr(server_conn, "_get_session_id_cb", None): + elicitation_mode = session.effective_elicitation_mode + session_id = server_conn.session_id + if not session_id and server_conn._get_session_id_cb: try: session_id = server_conn._get_session_id_cb() # type: ignore[attr-defined] except Exception: session_id = None - metrics = getattr(server_conn, "transport_metrics", None) + metrics = server_conn.transport_metrics if metrics is not None: try: transport_snapshot = metrics.snapshot() @@ -1007,15 +1022,13 @@ async def collect_server_status(self) -> dict[str, ServerStatus]: data={"error": str(exc)}, ) - if ( - server_cfg is None - and self.context - and getattr(self.context, "server_registry", None) - ): - try: - server_cfg = self.context.server_registry.get_server_config(server_name) - except Exception: - server_cfg = None + if server_cfg is None: + server_registry = self.context.server_registry if self.context else None + if server_registry is not None: + try: + server_cfg = server_registry.get_server_config(server_name) + except Exception: + server_cfg = None if server_cfg is not None: instructions_enabled = ( @@ -1023,23 +1036,23 @@ async def collect_server_status(self) -> dict[str, ServerStatus]: if instructions_enabled is not None else server_cfg.include_instructions ) - roots = getattr(server_cfg, "roots", None) + roots = server_cfg.roots roots_configured = bool(roots) roots_count = len(roots) if roots else 0 - transport = getattr(server_cfg, "transport", transport) - elicitation = getattr(server_cfg, "elicitation", None) + transport = 
server_cfg.transport or transport + elicitation = server_cfg.elicitation elicitation_mode = ( getattr(elicitation, "mode", None) if elicitation else elicitation_mode ) - sampling_cfg = getattr(server_cfg, "sampling", None) - spoofing_enabled = bool(getattr(server_cfg, "implementation", None)) - if implementation_name is None and getattr(server_cfg, "implementation", None): + sampling_cfg = server_cfg.sampling + spoofing_enabled = server_cfg.implementation is not None + if implementation_name is None and server_cfg.implementation is not None: implementation_name = server_cfg.implementation.name - implementation_version = getattr(server_cfg.implementation, "version", None) + implementation_version = server_cfg.implementation.version if session_id is None: if server_cfg.transport == "stdio": session_id = "local" - elif server_conn and getattr(server_conn, "_get_session_id_cb", None): + elif server_conn and server_conn._get_session_id_cb: try: session_id = server_conn._get_session_id_cb() # type: ignore[attr-defined] except Exception: @@ -1108,7 +1121,7 @@ async def _execute_on_server( operation_type: str, operation_name: str, method_name: str, - method_args: dict[str, Any] = None, + method_args: dict[str, Any] | None = None, error_factory: Callable[[str], R] | None = None, progress_callback: ProgressFnT | None = None, ) -> R: @@ -1173,10 +1186,14 @@ async def try_execute(client: ClientSession): # Try initial execution try: if self.connection_persistence: - server_connection = await self._persistent_connection_manager.get_server( + manager = self._require_connection_manager() + server_connection = await manager.get_server( server_name, client_session_factory=self._create_session_factory(server_name) ) - result = await try_execute(server_connection.session) + session = server_connection.session + if session is None: + raise RuntimeError(f"Server session not initialized for '{server_name}'") + result = await try_execute(session) success_flag = True else: logger.debug( @@ -1187,8 +1204,9 @@ async def try_execute(client: ClientSession): "agent_name": self.agent_name, }, ) + server_registry = self._require_server_registry() async with gen_client( - server_name, server_registry=self.context.server_registry + server_name, server_registry=server_registry ) as client: result = await try_execute(client) logger.debug( @@ -1217,6 +1235,13 @@ async def try_execute(client: ClientSession): if success_flag is not None: await self._record_server_call(server_name, operation_type, success_flag) + if result is None: + error_msg = ( + f"Failed to {method_name} '{operation_name}' on server '{server_name}'" + ) + if error_factory: + return error_factory(error_msg) + raise RuntimeError(error_msg) return result async def _handle_connection_error( @@ -1233,15 +1258,20 @@ async def _handle_connection_error( try: if self.connection_persistence: # Force disconnect and create fresh connection - server_connection = await self._persistent_connection_manager.reconnect_server( + manager = self._require_connection_manager() + server_connection = await manager.reconnect_server( server_name, client_session_factory=self._create_session_factory(server_name), ) - result = await try_execute(server_connection.session) + session = server_connection.session + if session is None: + raise RuntimeError(f"Server session not initialized for '{server_name}'") + result = await try_execute(session) else: # For non-persistent connections, just try again + server_registry = self._require_server_registry() async with gen_client( - server_name, 
server_registry=self.context.server_registry + server_name, server_registry=server_registry ) as client: result = await try_execute(client) @@ -1287,8 +1317,9 @@ async def _handle_session_terminated( # Check if reconnect_on_disconnect is enabled for this server server_config = None - if self.context and getattr(self.context, "server_registry", None): - server_config = self.context.server_registry.get_server_config(server_name) + server_registry = self.context.server_registry if self.context else None + if server_registry is not None: + server_config = server_registry.get_server_config(server_name) reconnect_enabled = server_config and server_config.reconnect_on_disconnect @@ -1313,15 +1344,20 @@ async def _handle_session_terminated( try: if self.connection_persistence: - server_connection = await self._persistent_connection_manager.reconnect_server( + manager = self._require_connection_manager() + server_connection = await manager.reconnect_server( server_name, client_session_factory=self._create_session_factory(server_name), ) - result = await try_execute(server_connection.session) + session = server_connection.session + if session is None: + raise RuntimeError(f"Server session not initialized for '{server_name}'") + result = await try_execute(session) else: # For non-persistent connections, just try again + server_registry = self._require_server_registry() async with gen_client( - server_name, server_registry=self.context.server_registry + server_name, server_registry=server_registry ) as client: result = await try_execute(client) @@ -1359,7 +1395,9 @@ async def _handle_session_terminated( else: raise Exception(error_msg) - async def _parse_resource_name(self, name: str, resource_type: str) -> tuple[str, str]: + async def _parse_resource_name( + self, name: str, resource_type: str + ) -> tuple[str | None, str]: """ Parse a possibly namespaced resource name into server name and local resource name. 
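The _require_context() / _require_server_registry() / _require_connection_manager() helpers above all apply one idiom: collapse an `X | None` field to a non-optional value at the call site by raising once with a clear message, instead of scattering getattr() fallbacks. A generic sketch of the idiom, with stand-in names:

class ConnectionManager:  # illustrative stand-in
    def get(self, name: str) -> str:
        return f"conn:{name}"


class Aggregator:
    def __init__(self) -> None:
        self._manager: ConnectionManager | None = None

    def _require_manager(self) -> ConnectionManager:
        # Raising here gives every caller a narrowed, non-optional type and
        # a single authoritative error message for the unset case.
        if self._manager is None:
            raise RuntimeError("Persistent connection manager is not initialized")
        return self._manager

    def call(self, server: str) -> str:
        return self._require_manager().get(server)


agg = Aggregator()
try:
    agg.call("fetch")
except RuntimeError as exc:
    print(exc)

agg._manager = ConnectionManager()
assert agg.call("fetch") == "conn:fetch"

Type checkers treat the helper's return type as ground truth, so the manager.get_server(...) call sites above type-check without `# type: ignore` comments or repeated assert statements.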
diff --git a/src/fast_agent/mcp/mcp_connection_manager.py b/src/fast_agent/mcp/mcp_connection_manager.py index 4d996fe9e..244c26882 100644 --- a/src/fast_agent/mcp/mcp_connection_manager.py +++ b/src/fast_agent/mcp/mcp_connection_manager.py @@ -185,12 +185,9 @@ async def initialize_session(self) -> None: self.server_capabilities = result.capabilities # InitializeResult exposes server info via `serverInfo`; keep fallback for older fields - implementation = getattr(result, "serverInfo", None) - if implementation is None: - implementation = getattr(result, "implementation", None) - self.server_implementation = implementation + self.server_implementation = result.serverInfo - raw_instructions = getattr(result, "instructions", None) + raw_instructions = result.instructions self.server_instructions_available = bool(raw_instructions) # Store instructions if provided by the server and enabled in config diff --git a/src/fast_agent/mcp/mcp_content.py b/src/fast_agent/mcp/mcp_content.py index 761ff7735..ac5cd24a2 100644 --- a/src/fast_agent/mcp/mcp_content.py +++ b/src/fast_agent/mcp/mcp_content.py @@ -15,6 +15,7 @@ ContentBlock, EmbeddedResource, ImageContent, + PromptMessage, ReadResourceResult, ResourceContents, TextContent, @@ -27,6 +28,7 @@ is_binary_content, is_image_mime_type, ) +from fast_agent.types import PromptMessageExtended def MCPText( @@ -87,6 +89,9 @@ def MCPImage( if not mime_type: mime_type = "image/png" # Default + if data is None: + raise ValueError("Image data is missing after path resolution") + b64_data = base64.b64encode(data).decode("ascii") return { @@ -150,7 +155,16 @@ def MCPFile( def MCPPrompt( - *content_items: Union[dict, str, Path, bytes, ContentBlock, ReadResourceResult], + *content_items: Union[ + dict, + str, + Path, + bytes, + ContentBlock, + ReadResourceResult, + PromptMessage, + PromptMessageExtended, + ], role: Literal["user", "assistant"] = "user", ) -> list[dict]: """ @@ -181,6 +195,16 @@ def MCPPrompt( if isinstance(item, dict) and "role" in item and "content" in item: # Already a fully formed message result.append(item) + elif isinstance(item, PromptMessage): + # Use the prompt message role/content directly + result.append({"role": item.role, "content": item.content}) + elif isinstance(item, PromptMessageExtended): + # Expand multipart messages into standard PromptMessages + for msg in item.from_multipart(): + result.append({"role": msg.role, "content": msg.content}) + elif isinstance(item, ContentBlock): + # Already a content block, wrap in a message + result.append({"role": role, "content": item}) elif isinstance(item, str): # Simple text content result.append(MCPText(item, role=role)) @@ -198,25 +222,23 @@ def MCPPrompt( elif isinstance(item, bytes): # Raw binary data, assume image result.append(MCPImage(data=item, role=role)) - elif isinstance(item, TextContent): - # Already a TextContent, wrap in a message - result.append({"role": role, "content": item}) - elif isinstance(item, ImageContent): - # Already an ImageContent, wrap in a message - result.append({"role": role, "content": item}) - elif isinstance(item, EmbeddedResource): - # Already an EmbeddedResource, wrap in a message - result.append({"role": role, "content": item}) elif hasattr(item, "type") and item.type == "resource" and hasattr(item, "resource"): # Looks like an EmbeddedResource but may not be the exact class - result.append( - {"role": role, "content": EmbeddedResource(type="resource", resource=item.resource)} - ) - elif isinstance(item, ResourceContents): - # It's a ResourceContents, 
wrap it in an EmbeddedResource + resource = item.resource + if isinstance(resource, (TextResourceContents, BlobResourceContents)): + result.append( + {"role": role, "content": EmbeddedResource(type="resource", resource=resource)} + ) + else: + result.append(MCPText(str(item), role=role)) + elif isinstance(item, (TextResourceContents, BlobResourceContents)): + # It's a concrete ResourceContents, wrap it in an EmbeddedResource result.append( {"role": role, "content": EmbeddedResource(type="resource", resource=item)} ) + elif isinstance(item, ResourceContents): + # Fallback for unknown resource content shapes + result.append(MCPText(str(item), role=role)) elif isinstance(item, ReadResourceResult): # It's a ReadResourceResult, convert each resource content for resource_content in item.contents: @@ -234,14 +256,32 @@ def MCPPrompt( def User( - *content_items: Union[dict, str, Path, bytes, ContentBlock, ReadResourceResult], + *content_items: Union[ + dict, + str, + Path, + bytes, + ContentBlock, + ReadResourceResult, + PromptMessage, + PromptMessageExtended, + ], ) -> list[dict]: """Create user message(s) with various content types.""" return MCPPrompt(*content_items, role="user") def Assistant( - *content_items: Union[dict, str, Path, bytes, ContentBlock, ReadResourceResult], + *content_items: Union[ + dict, + str, + Path, + bytes, + ContentBlock, + ReadResourceResult, + PromptMessage, + PromptMessageExtended, + ], ) -> list[dict]: """Create assistant message(s) with various content types.""" return MCPPrompt(*content_items, role="assistant") diff --git a/src/fast_agent/mcp/prompt_serialization.py b/src/fast_agent/mcp/prompt_serialization.py index 2f6b1527d..7a2aab40a 100644 --- a/src/fast_agent/mcp/prompt_serialization.py +++ b/src/fast_agent/mcp/prompt_serialization.py @@ -294,7 +294,7 @@ def multipart_messages_to_delimited_format( # First, add all text content for content in message.content: - if content.type == "text": + if isinstance(content, TextContent): # Collect text content to combine text_contents.append(content.text) @@ -304,7 +304,7 @@ def multipart_messages_to_delimited_format( # Then add resources and images for content in message.content: - if content.type != "text": + if not isinstance(content, TextContent): # Resource or image - add delimiter and JSON delimited_content.append(resource_delimiter) @@ -316,7 +316,7 @@ def multipart_messages_to_delimited_format( else: # Don't combine text contents - preserve each content part in sequence for content in message.content: - if content.type == "text": + if isinstance(content, TextContent): # Add each text content separately delimited_content.append(content.text) else: diff --git a/src/fast_agent/mcp/prompts/prompt_server.py b/src/fast_agent/mcp/prompts/prompt_server.py index c511622c4..2829639e8 100644 --- a/src/fast_agent/mcp/prompts/prompt_server.py +++ b/src/fast_agent/mcp/prompts/prompt_server.py @@ -11,7 +11,7 @@ import logging import sys from pathlib import Path -from typing import Any, Awaitable, Callable, Union +from typing import Any, Awaitable, Callable, Sequence, Union from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp.prompts.base import ( @@ -48,7 +48,9 @@ mcp = FastMCP("Prompt Server") -def convert_to_fastmcp_messages(prompt_messages: list[Union[PromptMessage, PromptMessageExtended]]) -> list[Message]: +def convert_to_fastmcp_messages( + prompt_messages: Sequence[Union[PromptMessage, PromptMessageExtended]], +) -> list[Message]: """ Convert PromptMessage or PromptMessageExtended objects to FastMCP Message 
objects. This adapter prevents double-wrapping of messages and handles both types. @@ -63,26 +65,19 @@ def convert_to_fastmcp_messages(prompt_messages: list[Union[PromptMessage, Promp for msg in prompt_messages: # Handle both PromptMessage and PromptMessageExtended - if hasattr(msg, 'from_multipart'): - # PromptMessageExtended - convert to regular PromptMessage format + if isinstance(msg, PromptMessageExtended): flat_messages = msg.from_multipart() - for flat_msg in flat_messages: - if flat_msg.role == "user": - result.append(UserMessage(content=flat_msg.content)) - elif flat_msg.role == "assistant": - result.append(AssistantMessage(content=flat_msg.content)) - else: - logger.warning(f"Unknown message role: {flat_msg.role}, defaulting to user") - result.append(UserMessage(content=flat_msg.content)) else: - # Regular PromptMessage - use directly - if msg.role == "user": - result.append(UserMessage(content=msg.content)) - elif msg.role == "assistant": - result.append(AssistantMessage(content=msg.content)) + flat_messages = [msg] + + for flat_msg in flat_messages: + if flat_msg.role == "user": + result.append(UserMessage(content=flat_msg.content)) + elif flat_msg.role == "assistant": + result.append(AssistantMessage(content=flat_msg.content)) else: - logger.warning(f"Unknown message role: {msg.role}, defaulting to user") - result.append(UserMessage(content=msg.content)) + logger.warning(f"Unknown message role: {flat_msg.role}, defaulting to user") + result.append(UserMessage(content=flat_msg.content)) return result diff --git a/src/fast_agent/mcp/prompts/prompt_template.py b/src/fast_agent/mcp/prompts/prompt_template.py index a82324e3a..1ef9647eb 100644 --- a/src/fast_agent/mcp/prompts/prompt_template.py +++ b/src/fast_agent/mcp/prompts/prompt_template.py @@ -7,7 +7,7 @@ import re from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, cast from mcp.types import ( EmbeddedResource, @@ -16,15 +16,14 @@ ) from pydantic import BaseModel, field_validator -from fast_agent.mcp.prompt_serialization import ( - multipart_messages_to_delimited_format, -) +from fast_agent.mcp.prompt_serialization import multipart_messages_to_delimited_format from fast_agent.mcp.prompts.prompt_constants import ( ASSISTANT_DELIMITER, DEFAULT_DELIMITER_MAP, RESOURCE_DELIMITER, USER_DELIMITER, ) +from fast_agent.mcp.resource_utils import to_any_url from fast_agent.types import PromptMessageExtended @@ -46,16 +45,16 @@ class PromptContent(BaseModel): """Content of a prompt, which may include template variables""" text: str - role: str = "user" + role: MessageRole = "user" resources: list[str] = [] @field_validator("role") @classmethod - def validate_role(cls, role: str) -> str: + def validate_role(cls, role: str) -> MessageRole: """Validate that the role is a known value""" if role not in ("user", "assistant"): raise ValueError(f"Invalid role: {role}. 
Must be one of: user, assistant") - return role + return cast("MessageRole", role) def apply_substitutions(self, context: dict[str, Any]) -> "PromptContent": """Apply variable substitutions to the text and resources""" @@ -194,7 +193,7 @@ def apply_substitutions_to_extended( EmbeddedResource( type="resource", resource=TextResourceContents( - uri=f"resource://fast-agent/{resource_path}", + uri=to_any_url(f"resource://fast-agent/{resource_path}"), mimeType="text/plain", text=f"Content of {resource_path}", ), @@ -232,7 +231,7 @@ def to_extended_messages(self) -> list[PromptMessageExtended]: EmbeddedResource( type="resource", resource=TextResourceContents( - uri=f"resource://{resource_path}", + uri=to_any_url(f"resource://{resource_path}"), mimeType="text/plain", text=f"Content of {resource_path}", ), diff --git a/src/fast_agent/mcp/resource_utils.py b/src/fast_agent/mcp/resource_utils.py index 4f48fd525..880a97ba2 100644 --- a/src/fast_agent/mcp/resource_utils.py +++ b/src/fast_agent/mcp/resource_utils.py @@ -7,7 +7,7 @@ ImageContent, TextResourceContents, ) -from pydantic import AnyUrl +from pydantic import AnyUrl, TypeAdapter import fast_agent.mcp.mime_utils as mime_utils @@ -65,12 +65,20 @@ def load_resource_content(resource_path: str, prompt_files: list[Path]) -> Resou # Create a safe way to generate resource URIs that Pydantic accepts -def create_resource_uri(path: str) -> str: +_ANY_URL_ADAPTER = TypeAdapter(AnyUrl) + + +def to_any_url(value: str | AnyUrl) -> AnyUrl: + """Normalize a URI string to AnyUrl (validates via pydantic).""" + return _ANY_URL_ADAPTER.validate_python(value) + + +def create_resource_uri(path: str) -> AnyUrl: """Create a resource URI from a path""" - return f"resource://fast-agent/{Path(path).name}" + return to_any_url(f"resource://fast-agent/{Path(path).name}") -def create_resource_reference(uri: str, mime_type: str) -> "EmbeddedResource": +def create_resource_reference(uri: AnyUrl | str, mime_type: str) -> "EmbeddedResource": """ Create a reference to a resource without embedding its content directly. 
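# ---------------------------------------------------------------------------
# A minimal illustrative sketch (not from the patch) of the TypeAdapter
# pattern behind to_any_url(). Assumes pydantic v2. Models such as
# TextResourceContents coerce str to AnyUrl at validation time anyway, but
# normalizing eagerly keeps type checkers happy and fails fast on bad URIs.
from pydantic import AnyUrl, TypeAdapter, ValidationError

_adapter = TypeAdapter(AnyUrl)

validated = _adapter.validate_python("resource://fast-agent/report.txt")
print(type(validated).__name__)  # a validated AnyUrl, safe for uri= fields

try:
    _adapter.validate_python("no-scheme-here")  # missing scheme -> rejected
except ValidationError as exc:
    print(f"invalid URI: {exc.error_count()} error(s)")
# ---------------------------------------------------------------------------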
@@ -89,7 +97,7 @@ def create_resource_reference(uri: str, mime_type: str) -> "EmbeddedResource": # Create a resource reference resource_contents = TextResourceContents( - uri=uri, + uri=to_any_url(uri), mimeType=mime_type, text="", # Empty text as we're just referencing ) @@ -104,17 +112,12 @@ def create_embedded_resource( # Format a valid resource URI string resource_uri_str = create_resource_uri(resource_path) - # Create common resource args dict to reduce duplication - resource_args = { - "uri": resource_uri_str, # type: ignore - "mimeType": mime_type, - } - if is_binary: return EmbeddedResource( type="resource", resource=BlobResourceContents( - **resource_args, + uri=resource_uri_str, + mimeType=mime_type, blob=content, ), ) @@ -122,7 +125,8 @@ def create_embedded_resource( return EmbeddedResource( type="resource", resource=TextResourceContents( - **resource_args, + uri=resource_uri_str, + mimeType=mime_type, text=content, ), ) @@ -137,24 +141,28 @@ def create_image_content(data: str, mime_type: str) -> ImageContent: ) -def create_blob_resource(resource_path: str, content: str, mime_type: str) -> EmbeddedResource: +def create_blob_resource( + resource_path: str | AnyUrl, content: str, mime_type: str +) -> EmbeddedResource: """Create an embedded resource for binary data""" return EmbeddedResource( type="resource", resource=BlobResourceContents( - uri=resource_path, + uri=to_any_url(resource_path), mimeType=mime_type, blob=content, # Content should already be base64 encoded ), ) -def create_text_resource(resource_path: str, content: str, mime_type: str) -> EmbeddedResource: +def create_text_resource( + resource_path: str | AnyUrl, content: str, mime_type: str +) -> EmbeddedResource: """Create an embedded resource for text data""" return EmbeddedResource( type="resource", resource=TextResourceContents( - uri=resource_path, + uri=to_any_url(resource_path), mimeType=mime_type, text=content, ), @@ -193,12 +201,13 @@ def normalize_uri(uri_or_filename: str) -> str: def extract_title_from_uri(uri: AnyUrl) -> str: """Extract a readable title from a URI.""" # Simple attempt to get filename from path - uri_str = uri._url + uri_str = str(uri) try: # For HTTP(S) URLs if uri.scheme in ("http", "https"): # Get the last part of the path - path_parts = uri.path.split("/") + path = uri.path or "" + path_parts = path.split("/") if path else [] filename = next((p for p in reversed(path_parts) if p), "") return filename if filename else uri_str diff --git a/src/fast_agent/mcp/server/agent_server.py b/src/fast_agent/mcp/server/agent_server.py index 780cb17bd..841da3dbb 100644 --- a/src/fast_agent/mcp/server/agent_server.py +++ b/src/fast_agent/mcp/server/agent_server.py @@ -8,7 +8,7 @@ import signal import time from contextlib import AsyncExitStack, asynccontextmanager -from typing import Awaitable, Callable +from typing import Any, AsyncContextManager, Awaitable, Callable, Literal, Protocol, cast from mcp.server.fastmcp import Context as MCPContext from mcp.server.fastmcp import FastMCP @@ -21,6 +21,22 @@ logger = get_logger(__name__) +TransportMode = Literal["http", "sse", "stdio"] +McpTransportMode = Literal["streamable-http", "sse", "stdio"] + + +class _LocalSseTransport(Protocol): + connect_sse: Callable[..., AsyncContextManager[Any]] + _read_stream_writers: dict[Any, Any] + + +class _FastMCPLocalExtensions(Protocol): + _sse_transport: _LocalSseTransport | None + _lifespan_state: str + _on_shutdown: Callable[[], Awaitable[None]] + _server_should_exit: bool + + class AgentMCPServer: """Exposes FastAgent 
agents as MCP tools through an MCP server.""" @@ -56,7 +72,7 @@ def __init__( # Resource management self._exit_stack = AsyncExitStack() - self._active_connections: set[any] = set() + self._active_connections: set[object] = set() # Server state self._server_task = None @@ -282,7 +298,12 @@ async def _handle_shutdown_signal(self, is_term=False): print("Press Ctrl+C again to force exit.") self._graceful_shutdown_event.set() - def run(self, transport: str = "http", host: str = "0.0.0.0", port: int = 8000) -> None: + def run( + self, + transport: TransportMode = "http", + host: str = "0.0.0.0", + port: int = 8000, + ) -> None: """Run the MCP server synchronously.""" if transport in ["sse", "http"]: self.mcp_server.settings.host = host @@ -292,10 +313,13 @@ def run(self, transport: str = "http", host: str = "0.0.0.0", port: int = 8000) try: # Add any server attributes that might help with shutdown if not hasattr(self.mcp_server, "_server_should_exit"): - self.mcp_server._server_should_exit = False + setattr(self.mcp_server, "_server_should_exit", False) # Run the server - self.mcp_server.run(transport=transport) + mcp_transport: McpTransportMode = ( + "streamable-http" if transport == "http" else transport + ) + self.mcp_server.run(transport=mcp_transport) except KeyboardInterrupt: print("\nServer stopped by user (CTRL+C)") except SystemExit as e: @@ -314,7 +338,7 @@ def run(self, transport: str = "http", host: str = "0.0.0.0", port: int = 8000) pass else: # stdio try: - self.mcp_server.run(transport=transport) + self.mcp_server.run(transport="stdio") except KeyboardInterrupt: print("\nServer stopped by user (CTRL+C)") finally: @@ -322,7 +346,7 @@ def run(self, transport: str = "http", host: str = "0.0.0.0", port: int = 8000) asyncio.run(self._cleanup_stdio()) async def run_async( - self, transport: str = "http", host: str = "0.0.0.0", port: int = 8000 + self, transport: TransportMode = "http", host: str = "0.0.0.0", port: int = 8000 ) -> None: """Run the MCP server asynchronously with improved shutdown handling.""" # Use different handling strategies based on transport type @@ -334,7 +358,15 @@ async def run_async( self.mcp_server.settings.port = port # Start the server in a separate task so we can monitor it - self._server_task = asyncio.create_task(self._run_server_with_shutdown(transport)) + if transport == "http": + http_transport: Literal["http", "sse"] = "http" + elif transport == "sse": + http_transport = "sse" + else: + raise ValueError("HTTP/SSE handler received stdio transport") + self._server_task = asyncio.create_task( + self._run_server_with_shutdown(http_transport) + ) try: # Wait for the server task to complete @@ -376,7 +408,7 @@ async def run_async( # Only perform minimal cleanup needed for STDIO await self._cleanup_stdio() - async def _run_server_with_shutdown(self, transport: str): + async def _run_server_with_shutdown(self, transport: Literal["http", "sse"]): """Run the server with proper shutdown handling.""" # This method is used for SSE/HTTP transport if transport not in ["sse", "http"]: @@ -387,9 +419,11 @@ async def _run_server_with_shutdown(self, transport: str): try: # Patch SSE server to track connections if needed - if hasattr(self.mcp_server, "_sse_transport") and self.mcp_server._sse_transport: + mcp_ext = cast("_FastMCPLocalExtensions", self.mcp_server) + sse_transport = getattr(mcp_ext, "_sse_transport", None) + if sse_transport is not None: # Store the original connect_sse method - original_connect = self.mcp_server._sse_transport.connect_sse + original_connect 
= sse_transport.connect_sse # Create a wrapper that tracks connections @asynccontextmanager @@ -402,7 +436,7 @@ async def tracked_connect_sse(*args, **kwargs): self._active_connections.discard(streams) # Replace with our tracking version - self.mcp_server._sse_transport.connect_sse = tracked_connect_sse + sse_transport.connect_sse = tracked_connect_sse # Run the server based on transport type if transport == "sse": @@ -459,54 +493,52 @@ async def _close_sse_connections(self): # Close tracked connections for conn in list(self._active_connections): try: - if hasattr(conn, "close"): - await conn.close() - elif hasattr(conn, "aclose"): - await conn.aclose() + close = getattr(conn, "close", None) + if callable(close): + await close() + else: + aclose = getattr(conn, "aclose", None) + if callable(aclose): + await aclose() except Exception as e: logger.error(f"Error closing connection: {e}") self._active_connections.discard(conn) # Access the SSE transport if it exists to close stream writers - if ( - hasattr(self.mcp_server, "_sse_transport") - and self.mcp_server._sse_transport is not None - ): - sse = self.mcp_server._sse_transport + mcp_ext = cast("_FastMCPLocalExtensions", self.mcp_server) + sse = getattr(mcp_ext, "_sse_transport", None) + if sse is not None: # Close all read stream writers - if hasattr(sse, "_read_stream_writers"): - writers = list(sse._read_stream_writers.items()) - for session_id, writer in writers: + writers = list(sse._read_stream_writers.items()) + for session_id, writer in writers: + try: + logger.debug(f"Closing SSE connection: {session_id}") + # Instead of aclose, try to close more gracefully + # Send a special event to notify client, then close try: - logger.debug(f"Closing SSE connection: {session_id}") - # Instead of aclose, try to close more gracefully - # Send a special event to notify client, then close - try: - if hasattr(writer, "send") and not getattr(writer, "_closed", False): - try: - # Try to send a close event if possible - await writer.send(Exception("Server shutting down")) - except (AttributeError, asyncio.CancelledError): - pass - except Exception: - pass - - # Now close the stream - await writer.aclose() - sse._read_stream_writers.pop(session_id, None) - except Exception as e: - logger.error(f"Error closing SSE connection {session_id}: {e}") + if hasattr(writer, "send") and not getattr(writer, "_closed", False): + try: + # Try to send a close event if possible + await writer.send(Exception("Server shutting down")) + except (AttributeError, asyncio.CancelledError): + pass + except Exception: + pass + + # Now close the stream + await writer.aclose() + sse._read_stream_writers.pop(session_id, None) + except Exception as e: + logger.error(f"Error closing SSE connection {session_id}: {e}") # If we have a ASGI lifespan hook, try to signal closure - if ( - hasattr(self.mcp_server, "_lifespan_state") - and self.mcp_server._lifespan_state == "started" - ): + if getattr(mcp_ext, "_lifespan_state", None) == "started": logger.debug("Attempting to signal ASGI lifespan shutdown") try: - if hasattr(self.mcp_server, "_on_shutdown"): - await self.mcp_server._on_shutdown() + on_shutdown = getattr(mcp_ext, "_on_shutdown", None) + if on_shutdown is not None: + await on_shutdown() except Exception as e: logger.error(f"Error during ASGI lifespan shutdown: {e}") @@ -594,20 +626,17 @@ async def _cleanup_minimal(self): logger.info("Performing minimal cleanup before interrupt") # Only close SSE connection writers directly - if ( - hasattr(self.mcp_server, "_sse_transport") 
- and self.mcp_server._sse_transport is not None - ): - sse = self.mcp_server._sse_transport + mcp_ext = cast("_FastMCPLocalExtensions", self.mcp_server) + sse = getattr(mcp_ext, "_sse_transport", None) + if sse is not None: # Close all read stream writers - if hasattr(sse, "_read_stream_writers"): - for session_id, writer in list(sse._read_stream_writers.items()): - try: - await writer.aclose() - except Exception: - # Ignore errors during cleanup - pass + for session_id, writer in list(sse._read_stream_writers.items()): + try: + await writer.aclose() + except Exception: + # Ignore errors during cleanup + pass # Clear active connections set to prevent further operations self._active_connections.clear() diff --git a/src/fast_agent/types/__init__.py b/src/fast_agent/types/__init__.py index 0ff3b1897..ce857e0ed 100644 --- a/src/fast_agent/types/__init__.py +++ b/src/fast_agent/types/__init__.py @@ -35,6 +35,9 @@ # Message search utilities from .message_search import extract_first, extract_last, find_matches, search_messages +# Tool timing metadata +from .tool_timing import ToolTimingInfo, ToolTimings + __all__ = [ # Enums / types "LlmStopReason", @@ -56,4 +59,7 @@ "find_matches", "extract_first", "extract_last", + # Tool timing types + "ToolTimingInfo", + "ToolTimings", ] diff --git a/src/fast_agent/types/tool_timing.py b/src/fast_agent/types/tool_timing.py new file mode 100644 index 000000000..0df6faf55 --- /dev/null +++ b/src/fast_agent/types/tool_timing.py @@ -0,0 +1,11 @@ +"""Tool timing metadata types shared across agents and UI.""" + +from typing import TypeAlias, TypedDict + + +class ToolTimingInfo(TypedDict): + timing_ms: float + transport_channel: str | None + + +ToolTimings: TypeAlias = dict[str, ToolTimingInfo] diff --git a/src/fast_agent/ui/command_payloads.py b/src/fast_agent/ui/command_payloads.py index c6d96352f..86f013715 100644 --- a/src/fast_agent/ui/command_payloads.py +++ b/src/fast_agent/ui/command_payloads.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Literal +from typing import Literal, TypeGuard class CommandBase: @@ -102,3 +102,7 @@ class LoadHistoryCommand(CommandBase): | SaveHistoryCommand | LoadHistoryCommand ) + + +def is_command_payload(value: object) -> TypeGuard[CommandPayload]: + return isinstance(value, CommandBase) diff --git a/src/fast_agent/ui/enhanced_prompt.py b/src/fast_agent/ui/enhanced_prompt.py index a657c9532..40d157571 100644 --- a/src/fast_agent/ui/enhanced_prompt.py +++ b/src/fast_agent/ui/enhanced_prompt.py @@ -11,7 +11,7 @@ from collections.abc import Callable, Iterable from importlib.metadata import version from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from prompt_toolkit import PromptSession from prompt_toolkit.completion import Completer, Completion, WordCompleter @@ -31,7 +31,6 @@ from fast_agent.mcp.types import McpAgentProtocol from fast_agent.ui.command_payloads import ( ClearCommand, - CommandBase, CommandPayload, ListToolsCommand, LoadHistoryCommand, @@ -44,6 +43,7 @@ ShowUsageCommand, SkillsCommand, SwitchAgentCommand, + is_command_payload, ) from fast_agent.ui.mcp_display import render_mcp_status @@ -1385,8 +1385,8 @@ async def handle_special_commands( return False # If command is already a command payload, it has been pre-processed. 
- if isinstance(command, CommandBase): - return command + if is_command_payload(command): + return cast("CommandPayload", command) global agent_histories diff --git a/src/fast_agent/ui/interactive_prompt.py b/src/fast_agent/ui/interactive_prompt.py index f51e77fe8..ce16ce2f6 100644 --- a/src/fast_agent/ui/interactive_prompt.py +++ b/src/fast_agent/ui/interactive_prompt.py @@ -22,7 +22,6 @@ if TYPE_CHECKING: from fast_agent.core.agent_app import AgentApp - from fast_agent.ui.command_payloads import CommandPayload from mcp.types import Prompt, PromptMessage from rich import print as rich_print @@ -48,7 +47,24 @@ ) from fast_agent.skills.registry import format_skills_for_prompt from fast_agent.types import PromptMessageExtended -from fast_agent.ui.command_payloads import CommandBase +from fast_agent.ui.command_payloads import ( + ClearCommand, + CommandPayload, + ListPromptsCommand, + ListSkillsCommand, + ListToolsCommand, + LoadHistoryCommand, + SaveHistoryCommand, + SelectPromptCommand, + ShowHistoryCommand, + ShowMarkdownCommand, + ShowMcpStatusCommand, + ShowSystemCommand, + ShowUsageCommand, + SkillsCommand, + SwitchAgentCommand, + is_command_payload, +) from fast_agent.ui.enhanced_prompt import ( _display_agent_info_helper, get_argument_input, @@ -140,197 +156,199 @@ async def prompt_loop( command_result = await handle_special_commands(user_input, True) # Check if we should switch agents - if isinstance(command_result, CommandBase): - command_payload: CommandPayload = command_result - kind = command_payload.kind - if kind == "switch_agent": - new_agent = command_payload.agent_name - if new_agent in available_agents_set: - agent = new_agent - # Display new agent info immediately when switching - rich_print() # Add spacing - await _display_agent_info_helper(agent, prompt_provider) - continue - else: + if is_command_payload(command_result): + command_payload: CommandPayload = cast("CommandPayload", command_result) + match command_payload: + case SwitchAgentCommand(agent_name=new_agent): + if new_agent in available_agents_set: + agent = new_agent + # Display new agent info immediately when switching + rich_print() # Add spacing + await _display_agent_info_helper(agent, prompt_provider) + continue rich_print(f"[red]Agent '{new_agent}' not found[/red]") continue - # Keep the existing list_prompts handler for backward compatibility - elif kind == "list_prompts": - # Use the prompt_provider directly - await self._list_prompts(prompt_provider, agent) - continue - elif kind == "select_prompt": - # Handle prompt selection, using both list_prompts and apply_prompt - prompt_name = command_payload.prompt_name - prompt_index = command_payload.prompt_index - - # If a specific index was provided (from /prompt ) - if prompt_index is not None: - # First get a list of all prompts to look up the index - all_prompts = await self._get_all_prompts(prompt_provider, agent) - if not all_prompts: - rich_print("[yellow]No prompts available[/yellow]") + # Keep the existing list_prompts handler for backward compatibility + case ListPromptsCommand(): + # Use the prompt_provider directly + await self._list_prompts(prompt_provider, agent) + continue + case SelectPromptCommand( + prompt_name=prompt_name, prompt_index=prompt_index + ): + # Handle prompt selection, using both list_prompts and apply_prompt + # If a specific index was provided (from /prompt ) + if prompt_index is not None: + # First get a list of all prompts to look up the index + all_prompts = await self._get_all_prompts(prompt_provider, agent) + if not 
all_prompts: + rich_print("[yellow]No prompts available[/yellow]") + continue + + # Check if the index is valid + if 1 <= prompt_index <= len(all_prompts): + # Get the prompt at the specified index (1-based to 0-based) + selected_prompt = all_prompts[prompt_index - 1] + # Use the already created namespaced_name to ensure consistency + await self._select_prompt( + prompt_provider, + agent, + selected_prompt["namespaced_name"], + ) + else: + rich_print( + f"[red]Invalid prompt number: {prompt_index}. Valid range is 1-{len(all_prompts)}[/red]" + ) + # Show the prompt list for convenience + await self._list_prompts(prompt_provider, agent) + else: + # Use the name-based selection + await self._select_prompt(prompt_provider, agent, prompt_name) + continue + case ListToolsCommand(): + # Handle tools list display + await self._list_tools(prompt_provider, agent) + continue + case ListSkillsCommand(): + await self._list_skills(prompt_provider, agent) + continue + case SkillsCommand(action=action, argument=argument): + payload = {"action": action, "argument": argument} + await self._handle_skills_command(prompt_provider, agent, payload) + continue + case ShowUsageCommand(): + # Handle usage display + await self._show_usage(prompt_provider, agent) + continue + case ShowHistoryCommand(agent=target_agent): + target_agent = target_agent or agent + try: + agent_obj = prompt_provider._agent(target_agent) + except Exception: + rich_print(f"[red]Unable to load agent '{target_agent}'[/red]") + continue + + history = getattr(agent_obj, "message_history", []) + usage = getattr(agent_obj, "usage_accumulator", None) + display_history_overview(target_agent, history, usage) + continue + case ClearCommand(kind="clear_last", agent=target_agent): + target_agent = target_agent or agent + try: + agent_obj = prompt_provider._agent(target_agent) + except Exception: + rich_print(f"[red]Unable to load agent '{target_agent}'[/red]") continue - # Check if the index is valid - if 1 <= prompt_index <= len(all_prompts): - # Get the prompt at the specified index (1-based to 0-based) - selected_prompt = all_prompts[prompt_index - 1] - # Use the already created namespaced_name to ensure consistency - await self._select_prompt( - prompt_provider, - agent, - selected_prompt["namespaced_name"], + removed_message = None + pop_callable = getattr(agent_obj, "pop_last_message", None) + if callable(pop_callable): + removed_message = pop_callable() + else: + history = getattr(agent_obj, "message_history", []) + if history: + try: + removed_message = history.pop() + except Exception: + removed_message = None + + if removed_message: + role = getattr(removed_message, "role", "message") + role_display = ( + role.capitalize() if isinstance(role, str) else "Message" + ) + rich_print( + f"[green]Removed last {role_display} for agent '{target_agent}'.[/green]" ) else: rich_print( - f"[red]Invalid prompt number: {prompt_index}. 
Valid range is 1-{len(all_prompts)}[/red]" + f"[yellow]No messages to remove for agent '{target_agent}'.[/yellow]" ) - # Show the prompt list for convenience - await self._list_prompts(prompt_provider, agent) - else: - # Use the name-based selection - await self._select_prompt(prompt_provider, agent, prompt_name) - continue - elif kind == "list_tools": - # Handle tools list display - await self._list_tools(prompt_provider, agent) - continue - elif kind == "list_skills": - await self._list_skills(prompt_provider, agent) - continue - elif kind == "skills_command": - payload = { - "action": command_payload.action, - "argument": command_payload.argument, - } - await self._handle_skills_command(prompt_provider, agent, payload) - continue - elif kind == "show_usage": - # Handle usage display - await self._show_usage(prompt_provider, agent) - continue - elif kind == "show_history": - target_agent = command_payload.agent or agent - try: - agent_obj = prompt_provider._agent(target_agent) - except Exception: - rich_print(f"[red]Unable to load agent '{target_agent}'[/red]") - continue - - history = getattr(agent_obj, "message_history", []) - usage = getattr(agent_obj, "usage_accumulator", None) - display_history_overview(target_agent, history, usage) - continue - elif kind == "clear_last": - target_agent = command_payload.agent or agent - try: - agent_obj = prompt_provider._agent(target_agent) - except Exception: - rich_print(f"[red]Unable to load agent '{target_agent}'[/red]") continue + case ClearCommand(kind="clear_history", agent=target_agent): + target_agent = target_agent or agent + try: + agent_obj = prompt_provider._agent(target_agent) + except Exception: + rich_print(f"[red]Unable to load agent '{target_agent}'[/red]") + continue - removed_message = None - pop_callable = getattr(agent_obj, "pop_last_message", None) - if callable(pop_callable): - removed_message = pop_callable() - else: - history = getattr(agent_obj, "message_history", []) - if history: + if hasattr(agent_obj, "clear"): try: - removed_message = history.pop() - except Exception: - removed_message = None - - if removed_message: - role = getattr(removed_message, "role", "message") - role_display = role.capitalize() if isinstance(role, str) else "Message" - rich_print( - f"[green]Removed last {role_display} for agent '{target_agent}'.[/green]" - ) - else: - rich_print( - f"[yellow]No messages to remove for agent '{target_agent}'.[/yellow]" - ) - continue - elif kind == "clear_history": - target_agent = command_payload.agent or agent - try: - agent_obj = prompt_provider._agent(target_agent) - except Exception: - rich_print(f"[red]Unable to load agent '{target_agent}'[/red]") - continue - - if hasattr(agent_obj, "clear"): - try: - agent_obj.clear() + agent_obj.clear() + rich_print( + f"[green]History cleared for agent '{target_agent}'.[/green]" + ) + except Exception as exc: + rich_print( + f"[red]Failed to clear history for '{target_agent}': {exc}[/red]" + ) + else: rich_print( - f"[green]History cleared for agent '{target_agent}'.[/green]" + f"[yellow]Agent '{target_agent}' does not support clearing history.[/yellow]" ) - except Exception as exc: - rich_print( - f"[red]Failed to clear history for '{target_agent}': {exc}[/red]" + continue + case ShowSystemCommand(): + # Handle system prompt display + await self._show_system(prompt_provider, agent) + continue + case ShowMarkdownCommand(): + # Handle markdown display + await self._show_markdown(prompt_provider, agent) + continue + case ShowMcpStatusCommand(): + rich_print() + await 
show_mcp_status(agent, prompt_provider) + continue + case SaveHistoryCommand(filename=filename): + # Save history for the current agent + try: + agent_obj = prompt_provider._agent(agent) + + # Prefer type-safe exporter over magic string + saved_path = await HistoryExporter.save(agent_obj, filename) + rich_print(f"[green]History saved to {saved_path}[/green]") + except Exception: + # Fallback to magic string path for maximum compatibility + control = CONTROL_MESSAGE_SAVE_HISTORY + ( + f" {filename}" if filename else "" ) - else: - rich_print( - f"[yellow]Agent '{target_agent}' does not support clearing history.[/yellow]" - ) - continue - elif kind == "show_system": - # Handle system prompt display - await self._show_system(prompt_provider, agent) - continue - elif kind == "show_markdown": - # Handle markdown display - await self._show_markdown(prompt_provider, agent) - continue - elif kind == "show_mcp_status": - rich_print() - await show_mcp_status(agent, prompt_provider) - continue - elif kind == "save_history": - # Save history for the current agent - filename = command_payload.filename - try: - agent_obj = prompt_provider._agent(agent) - - # Prefer type-safe exporter over magic string - saved_path = await HistoryExporter.save(agent_obj, filename) - rich_print(f"[green]History saved to {saved_path}[/green]") - except Exception: - # Fallback to magic string path for maximum compatibility - control = CONTROL_MESSAGE_SAVE_HISTORY + ( - f" {filename}" if filename else "" - ) - result = await send_func(control, agent) - if result: - rich_print(f"[green]{result}[/green]") - continue - elif kind == "load_history": - # Load history for the current agent - if command_payload.error: - rich_print(f"[red]{command_payload.error}[/red]") + result = await send_func(control, agent) + if result: + rich_print(f"[green]{result}[/green]") continue + case LoadHistoryCommand(filename=filename, error=error): + # Load history for the current agent + if error: + rich_print(f"[red]{error}[/red]") + continue + + if filename is None: + rich_print("[red]Filename required for load_history[/red]") + continue - filename = command_payload.filename - try: - from fast_agent.mcp.prompts.prompt_load import load_history_into_agent + try: + from fast_agent.mcp.prompts.prompt_load import ( + load_history_into_agent, + ) - # Get the agent object and its underlying LLM - agent_obj = prompt_provider._agent(agent) + # Get the agent object and its underlying LLM + agent_obj = prompt_provider._agent(agent) - # Load history directly without triggering LLM call - load_history_into_agent(agent_obj, Path(filename)) + # Load history directly without triggering LLM call + load_history_into_agent(agent_obj, Path(filename)) - msg_count = len(agent_obj.message_history) - rich_print( - f"[green]Loaded {msg_count} messages from {filename}[/green]" - ) - except FileNotFoundError: - rich_print(f"[red]File not found: {filename}[/red]") - except Exception as e: - rich_print(f"[red]Error loading history: {e}[/red]") - continue + msg_count = len(agent_obj.message_history) + rich_print( + f"[green]Loaded {msg_count} messages from {filename}[/green]" + ) + except FileNotFoundError: + rich_print(f"[red]File not found: {filename}[/red]") + except Exception as e: + rich_print(f"[red]Error loading history: {e}[/red]") + continue + case _: + pass # Skip further processing if: # 1. 
The command was handled (command_result is truthy) @@ -339,8 +357,8 @@ async def prompt_loop( # This fixes the issue where /prompt without arguments gets sent to the LLM if ( command_result - or isinstance(user_input, CommandBase) - or isinstance(command_result, CommandBase) + or is_command_payload(user_input) + or is_command_payload(command_result) ): continue @@ -1199,7 +1217,7 @@ async def _refresh_agent_skills(self, prompt_provider: "AgentApp", agent_name: s await rebuild_agent_instruction( agent, skill_manifests=manifests, - instruction_context=instruction_context, + context=instruction_context, skill_registry=registry, ) diff --git a/src/fast_agent/ui/mcp_display.py b/src/fast_agent/ui/mcp_display.py index 116f67a37..17490f54d 100644 --- a/src/fast_agent/ui/mcp_display.py +++ b/src/fast_agent/ui/mcp_display.py @@ -303,10 +303,7 @@ def _build_capability_text(tokens: list[tuple[str, str]]) -> Text: def _format_relative_time(dt: datetime | None) -> str: if dt is None: return "never" - try: - now = datetime.now(timezone.utc) - except Exception: - now = datetime.utcnow().replace(tzinfo=timezone.utc) + now = datetime.now(timezone.utc) seconds = max(0, (now - dt).total_seconds()) return _format_compact_duration(seconds) or "<1s" diff --git a/src/fast_agent/ui/mcp_ui_utils.py b/src/fast_agent/ui/mcp_ui_utils.py index bfe8ac194..ae1ea663a 100644 --- a/src/fast_agent/ui/mcp_ui_utils.py +++ b/src/fast_agent/ui/mcp_ui_utils.py @@ -9,7 +9,7 @@ from pathlib import Path from typing import Iterable -from mcp.types import BlobResourceContents, EmbeddedResource, TextResourceContents +from mcp.types import BlobResourceContents, ContentBlock, EmbeddedResource, TextResourceContents """ Utilities for handling MCP-UI resources carried in PromptMessageExtended.channels. @@ -126,7 +126,7 @@ def _write_html_file(name_hint: str, html: str) -> str: return str(out_path.resolve()) -def ui_links_from_channel(resources: Iterable[EmbeddedResource]) -> list[UILink]: +def ui_links_from_channel(resources: Iterable[ContentBlock]) -> list[UILink]: """ Build local HTML files for a list of MCP-UI EmbeddedResources and return clickable links. 
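# ---------------------------------------------------------------------------
# An illustrative sketch (not from the patch) of the dispatch idiom the
# interactive-prompt refactor above moves to: a TypeGuard narrows `object`
# to the payload union for the type checker, and structural pattern matching
# destructures each dataclass by keyword. The class names here are
# hypothetical stand-ins, not the real command_payloads types.
from dataclasses import dataclass
from typing import TypeGuard


class CommandBase:
    pass


@dataclass
class SwitchAgent(CommandBase):
    agent_name: str


@dataclass
class SaveHistory(CommandBase):
    filename: str | None = None


Payload = SwitchAgent | SaveHistory


def is_payload(value: object) -> TypeGuard[Payload]:
    return isinstance(value, CommandBase)


def dispatch(value: object) -> str:
    if not is_payload(value):
        return "plain user input"
    match value:
        case SwitchAgent(agent_name=name):
            return f"switch to {name}"
        case SaveHistory(filename=fname):
            return f"save history to {fname or '<default>'}"
        case _:
            return "unhandled payload"
# ---------------------------------------------------------------------------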
@@ -136,7 +136,10 @@ def ui_links_from_channel(resources: Iterable[EmbeddedResource]) -> list[UILink] - application/vnd.mcp-ui.remote-dom* : currently unsupported; generate a placeholder page """ links: list[UILink] = [] - for emb in resources: + for item in resources: + if not isinstance(item, EmbeddedResource): + continue + emb = item res = emb.resource uri = str(getattr(res, "uri", "")) if getattr(res, "uri", None) else None mime = getattr(res, "mimeType", "") or "" diff --git a/src/fast_agent/ui/tool_display.py b/src/fast_agent/ui/tool_display.py index a2360e57c..4cfb7d0d3 100644 --- a/src/fast_agent/ui/tool_display.py +++ b/src/fast_agent/ui/tool_display.py @@ -53,7 +53,9 @@ def show_tool_result( for tool_cfg in skybridge_config.tools: if tool_cfg.tool_name == tool_name and tool_cfg.is_valid: is_skybridge_tool = True - skybridge_resource_uri = tool_cfg.resource_uri + skybridge_resource_uri = ( + str(tool_cfg.resource_uri) if tool_cfg.resource_uri is not None else None + ) break if result.isError: diff --git a/tests/integration/acp/test_acp_skills_manager.py b/tests/integration/acp/test_acp_skills_manager.py index 0b234ba5a..0042f91ba 100644 --- a/tests/integration/acp/test_acp_skills_manager.py +++ b/tests/integration/acp/test_acp_skills_manager.py @@ -4,31 +4,80 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Sequence, cast import pytest from fast_agent.acp.slash_commands import SlashCommandHandler from fast_agent.config import get_settings -from fast_agent.skills.registry import SkillManifest, format_skills_for_prompt if TYPE_CHECKING: from fast_agent.core.fastagent import AgentInstance + from fast_agent.skills.registry import SkillManifest, SkillRegistry + + +@dataclass +class StubAggregator: + """Stub aggregator for testing.""" + + async def get_server_instructions(self) -> dict[str, tuple[str | None, list[str]]]: + return {} @dataclass class SkillAgent: + """Skill agent that implements McpInstructionCapable protocol for testing.""" + name: str - instruction: str = "" + _instruction: str = "" + _instruction_template: str = "{{agentSkills}}" + _instruction_context: dict[str, str] = field(default_factory=dict) + _skill_manifests: list[SkillManifest] = field(default_factory=list) + _skill_registry: SkillRegistry | None = None + _aggregator: StubAggregator = field(default_factory=StubAggregator) message_history: list[Any] = field(default_factory=list) llm: Any = None - _skill_manifests: list[SkillManifest] = field(default_factory=list) - def set_skill_manifests(self, manifests: list[SkillManifest]) -> None: + @property + def instruction(self) -> str: + return self._instruction + + def set_instruction(self, instruction: str) -> None: + self._instruction = instruction + + @property + def instruction_template(self) -> str: + return self._instruction_template + + @property + def instruction_context(self) -> dict[str, str]: + return self._instruction_context + + def set_instruction_context(self, context: dict[str, str]) -> None: + self._instruction_context = context + + @property + def aggregator(self) -> StubAggregator: + return self._aggregator + + @property + def skill_manifests(self) -> Sequence[SkillManifest]: + return self._skill_manifests + + @property + def skill_registry(self) -> SkillRegistry | None: + return self._skill_registry + + @skill_registry.setter + def skill_registry(self, value: SkillRegistry | None) -> None: + self._skill_registry = value + + def 
set_skill_manifests(self, manifests: Sequence[SkillManifest]) -> None: self._skill_manifests = list(manifests) - async def rebuild_instruction_templates(self) -> None: - self.instruction = format_skills_for_prompt(self._skill_manifests) + @property + def has_filesystem_runtime(self) -> bool: + return False @dataclass diff --git a/tests/integration/acp/test_acp_slash_commands.py b/tests/integration/acp/test_acp_slash_commands.py index 50a0890ca..f9a78dbfd 100644 --- a/tests/integration/acp/test_acp_slash_commands.py +++ b/tests/integration/acp/test_acp_slash_commands.py @@ -13,6 +13,7 @@ from mcp.types import TextContent from fast_agent.acp.slash_commands import SlashCommandHandler +from fast_agent.agents.agent_types import AgentType from fast_agent.constants import FAST_AGENT_ERROR_CHANNEL from fast_agent.mcp.prompt_message_extended import PromptMessageExtended @@ -36,6 +37,8 @@ class StubAgent: llm: Any = None cleared: bool = False popped: bool = False + agent_type: AgentType = AgentType.BASIC + name: str = "test-agent" def clear(self, clear_prompts: bool = False) -> None: self.cleared = True diff --git a/tests/integration/acp/test_set_model_validation.py b/tests/integration/acp/test_set_model_validation.py new file mode 100644 index 000000000..54d9f201b --- /dev/null +++ b/tests/integration/acp/test_set_model_validation.py @@ -0,0 +1,221 @@ +"""Tests for /set-model validation in ACP mode.""" + +from __future__ import annotations + +import pytest + +from fast_agent.llm.hf_inference_lookup import ( + InferenceProvider, + InferenceProviderLookupResult, + InferenceProviderStatus, + ModelValidationResult, + validate_hf_model, +) + + +def _make_valid_model_lookup() -> InferenceProviderLookupResult: + """Create a lookup result for a valid model with providers.""" + return InferenceProviderLookupResult( + model_id="moonshotai/Kimi-K2-Instruct-0905", + exists=True, + providers=[ + InferenceProvider( + name="groq", + status=InferenceProviderStatus.LIVE, + provider_id="moonshotai/kimi-k2-instruct-0905", + task="conversational", + is_model_author=False, + ), + InferenceProvider( + name="together", + status=InferenceProviderStatus.LIVE, + provider_id="moonshotai/Kimi-K2-Instruct-0905", + task="conversational", + is_model_author=False, + ), + ], + ) + + +def _make_nonexistent_model_lookup(model_id: str) -> InferenceProviderLookupResult: + """Create a lookup result for a non-existent model.""" + return InferenceProviderLookupResult( + model_id=model_id, + exists=False, + providers=[], + error=f"Model '{model_id}' not found on HuggingFace", + ) + + +def _make_no_providers_lookup(model_id: str) -> InferenceProviderLookupResult: + """Create a lookup result for a model without providers.""" + return InferenceProviderLookupResult( + model_id=model_id, + exists=True, + providers=[], + ) + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_validate_rejects_nonexistent_model() -> None: + """Test that validation rejects models that don't exist on HuggingFace.""" + + async def stub_lookup(model_id: str) -> InferenceProviderLookupResult: + return _make_nonexistent_model_lookup(model_id) + + result = await validate_hf_model( + "hf.fake-org/nonexistent-model", + lookup_fn=stub_lookup, + ) + + assert isinstance(result, ModelValidationResult) + assert result.valid is False + assert result.error is not None + assert "not found" in result.error + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_validate_rejects_model_without_providers() -> None: + """Test that validation rejects models 
without inference providers.""" + + async def stub_lookup(model_id: str) -> InferenceProviderLookupResult: + return _make_no_providers_lookup(model_id) + + result = await validate_hf_model( + "hf.some-org/model-without-providers", + lookup_fn=stub_lookup, + ) + + assert isinstance(result, ModelValidationResult) + assert result.valid is False + assert result.error is not None + assert "no inference providers" in result.error + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_validate_accepts_valid_model_with_providers() -> None: + """Test that validation accepts models with inference providers.""" + + async def stub_lookup(model_id: str) -> InferenceProviderLookupResult: + return _make_valid_model_lookup() + + result = await validate_hf_model( + "hf.moonshotai/Kimi-K2-Instruct-0905", + lookup_fn=stub_lookup, + ) + + assert isinstance(result, ModelValidationResult) + assert result.valid is True + assert result.error is None + assert result.display_message != "" + assert "Available providers" in result.display_message + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_validate_accepts_model_with_provider_suffix() -> None: + """Test that validation works with model:provider format.""" + + async def stub_lookup(model_id: str) -> InferenceProviderLookupResult: + # Model ID should have the provider suffix stripped + assert ":" not in model_id + return _make_valid_model_lookup() + + result = await validate_hf_model( + "hf.moonshotai/Kimi-K2-Instruct-0905:together", + lookup_fn=stub_lookup, + ) + + assert result.valid is True + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_validate_skips_non_hf_models() -> None: + """Test that validation skips models without org/model format.""" + lookup_called = False + + async def stub_lookup(model_id: str) -> InferenceProviderLookupResult: + nonlocal lookup_called + lookup_called = True + return _make_valid_model_lookup() + + # Non-HF model (no slash) - should skip validation + result = await validate_hf_model("gpt-4o", lookup_fn=stub_lookup) + + assert result.valid is True + assert lookup_called is False # Lookup should not be called + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_validate_handles_lookup_exception() -> None: + """Test that validation handles lookup failures gracefully.""" + + async def stub_lookup(model_id: str) -> InferenceProviderLookupResult: + raise Exception("Network error") + + result = await validate_hf_model( + "hf.some-org/some-model", + lookup_fn=stub_lookup, + ) + + assert result.valid is False + assert result.error is not None + assert "Failed to validate" in result.error + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_validate_resolves_aliases_to_hf_models() -> None: + """Test that aliases like 'kimi' are resolved and show provider info.""" + from fast_agent.llm.model_factory import ModelFactory + + # Find an alias that resolves to an HF model + hf_alias = None + resolved_model = None + for alias, model in ModelFactory.MODEL_ALIASES.items(): + if model.startswith("hf."): + hf_alias = alias + resolved_model = model + break + + if hf_alias is None: + pytest.skip("No HF model aliases found in MODEL_ALIASES") + + # Extract the expected model ID from the resolved model + expected_model_id = resolved_model[3:] # Strip "hf." 
+ if ":" in expected_model_id: + expected_model_id = expected_model_id.rsplit(":", 1)[0] + + lookup_called_with: list[str] = [] + + async def stub_lookup(model_id: str) -> InferenceProviderLookupResult: + lookup_called_with.append(model_id) + return InferenceProviderLookupResult( + model_id=model_id, + exists=True, + providers=[ + InferenceProvider( + name="groq", + status=InferenceProviderStatus.LIVE, + provider_id="test", + task="conversational", + is_model_author=False, + ), + ], + ) + + # Call with the alias (e.g., "kimi") and pass aliases dict + result = await validate_hf_model( + hf_alias, aliases=ModelFactory.MODEL_ALIASES, lookup_fn=stub_lookup + ) + + # Should have resolved the alias and looked up the HF model + assert result.valid is True + assert len(lookup_called_with) == 1 + assert lookup_called_with[0] == expected_model_id + assert "Available providers" in result.display_message diff --git a/tests/integration/api/test_prompt_commands.py b/tests/integration/api/test_prompt_commands.py index 246e91921..8332025a5 100644 --- a/tests/integration/api/test_prompt_commands.py +++ b/tests/integration/api/test_prompt_commands.py @@ -4,6 +4,7 @@ import pytest +from fast_agent.ui.command_payloads import SelectPromptCommand, is_command_payload from fast_agent.ui.enhanced_prompt import handle_special_commands @@ -11,26 +12,25 @@ async def test_command_handling_for_prompts(): """Test the command handling functions for /prompts and /prompt commands.""" # Test /prompts command after it's been pre-processed - # The pre-processed form of "/prompts" is {"select_prompt": True, "prompt_name": None} - result = await handle_special_commands({"select_prompt": True, "prompt_name": None}, True) - assert isinstance(result, dict), "Result should be a dictionary" - assert "select_prompt" in result, "Result should have select_prompt key" - assert result["select_prompt"] is True - assert "prompt_name" in result - assert result["prompt_name"] is None + # The pre-processed form of "/prompts" is a SelectPromptCommand dataclass + input_cmd = SelectPromptCommand(prompt_name=None, prompt_index=None) + result = await handle_special_commands(input_cmd, True) + assert is_command_payload(result), "Result should be a command payload" + assert isinstance(result, SelectPromptCommand) + assert result.prompt_name is None + assert result.prompt_index is None # Test /prompt command after pre-processing - # The pre-processed form is {"select_prompt": True, "prompt_index": 3} - result = await handle_special_commands({"select_prompt": True, "prompt_index": 3}, True) - assert isinstance(result, dict), "Result should be a dictionary" - assert "select_prompt" in result - assert "prompt_index" in result - assert result["prompt_index"] == 3 + # The pre-processed form is a SelectPromptCommand with prompt_index + input_cmd = SelectPromptCommand(prompt_name=None, prompt_index=3) + result = await handle_special_commands(input_cmd, True) + assert is_command_payload(result), "Result should be a command payload" + assert isinstance(result, SelectPromptCommand) + assert result.prompt_index == 3 # Test /prompt command after pre-processing # The pre-processed form is "SELECT_PROMPT:my-prompt" result = await handle_special_commands("SELECT_PROMPT:my-prompt", True) - assert isinstance(result, dict), "Result should be a dictionary" - assert "select_prompt" in result - assert "prompt_name" in result - assert result["prompt_name"] == "my-prompt" + assert is_command_payload(result), "Result should be a command payload" + assert isinstance(result, 
SelectPromptCommand) + assert result.prompt_name == "my-prompt" diff --git a/tests/unit/fast_agent/llm/test_hf_inference_lookup_unit.py b/tests/unit/fast_agent/llm/test_hf_inference_lookup_unit.py new file mode 100644 index 000000000..6e4082a72 --- /dev/null +++ b/tests/unit/fast_agent/llm/test_hf_inference_lookup_unit.py @@ -0,0 +1,183 @@ +"""Tests for HuggingFace inference provider lookup.""" + +from __future__ import annotations + +import pytest + +from fast_agent.llm.hf_inference_lookup import ( + InferenceProvider, + InferenceProviderLookupResult, + InferenceProviderStatus, + lookup_inference_providers, +) + + +@pytest.mark.asyncio +async def test_lookup_with_valid_model_and_providers() -> None: + """Test lookup returns providers for a valid model.""" + + async def stub_lookup(model_id: str) -> InferenceProviderLookupResult: + return InferenceProviderLookupResult( + model_id=model_id, + exists=True, + providers=[ + InferenceProvider( + name="groq", + status=InferenceProviderStatus.LIVE, + provider_id="moonshotai/kimi-k2-instruct-0905", + task="conversational", + is_model_author=False, + ), + InferenceProvider( + name="together", + status=InferenceProviderStatus.LIVE, + provider_id="moonshotai/Kimi-K2-Instruct-0905", + task="conversational", + is_model_author=False, + ), + ], + ) + + result = await lookup_inference_providers( + "moonshotai/Kimi-K2-Instruct-0905", + lookup_fn=stub_lookup, + ) + + assert result.exists is True + assert result.has_providers is True + assert len(result.live_providers) == 2 + assert result.error is None + + +@pytest.mark.asyncio +async def test_lookup_with_nonexistent_model() -> None: + """Test lookup returns error for a non-existent model.""" + + async def stub_lookup(model_id: str) -> InferenceProviderLookupResult: + return InferenceProviderLookupResult( + model_id=model_id, + exists=False, + providers=[], + error=f"Model '{model_id}' not found on HuggingFace", + ) + + result = await lookup_inference_providers( + "fake-org/nonexistent-model", + lookup_fn=stub_lookup, + ) + + assert result.exists is False + assert result.has_providers is False + assert result.error is not None + assert "not found" in result.error + + +@pytest.mark.asyncio +async def test_lookup_with_model_without_providers() -> None: + """Test lookup for a model that exists but has no providers.""" + + async def stub_lookup(model_id: str) -> InferenceProviderLookupResult: + return InferenceProviderLookupResult( + model_id=model_id, + exists=True, + providers=[], # Model exists but no providers + ) + + result = await lookup_inference_providers( + "some-org/model-without-providers", + lookup_fn=stub_lookup, + ) + + assert result.exists is True + assert result.has_providers is False + assert result.error is None + + +@pytest.mark.asyncio +async def test_lookup_strips_hf_prefix() -> None: + """Test that the lookup function receives normalized model ID.""" + received_model_ids: list[str] = [] + + async def stub_lookup(model_id: str) -> InferenceProviderLookupResult: + received_model_ids.append(model_id) + return InferenceProviderLookupResult( + model_id=model_id, + exists=True, + providers=[], + ) + + # The stub receives the model_id directly, so we can check what was passed + await lookup_inference_providers( + "hf.moonshotai/Kimi-K2-Instruct-0905", + lookup_fn=stub_lookup, + ) + + # The lookup_fn is called before normalization in the test stub path + assert received_model_ids[0] == "hf.moonshotai/Kimi-K2-Instruct-0905" + + +@pytest.mark.asyncio +async def 
test_lookup_result_format_provider_list() -> None: + """Test formatting of provider list.""" + result = InferenceProviderLookupResult( + model_id="test/model", + exists=True, + providers=[ + InferenceProvider( + name="groq", + status=InferenceProviderStatus.LIVE, + provider_id="test", + task="conversational", + is_model_author=False, + ), + InferenceProvider( + name="together", + status=InferenceProviderStatus.LIVE, + provider_id="test", + task="conversational", + is_model_author=False, + ), + InferenceProvider( + name="staging-provider", + status=InferenceProviderStatus.STAGING, + provider_id="test", + task="conversational", + is_model_author=False, + ), + ], + ) + + # Should only include live providers + provider_list = result.format_provider_list() + assert "groq" in provider_list + assert "together" in provider_list + assert "staging-provider" not in provider_list + + +@pytest.mark.asyncio +async def test_lookup_result_format_model_strings() -> None: + """Test formatting of model strings with provider suffixes.""" + result = InferenceProviderLookupResult( + model_id="moonshotai/Kimi-K2-Instruct", + exists=True, + providers=[ + InferenceProvider( + name="groq", + status=InferenceProviderStatus.LIVE, + provider_id="test", + task="conversational", + is_model_author=False, + ), + InferenceProvider( + name="together", + status=InferenceProviderStatus.LIVE, + provider_id="test", + task="conversational", + is_model_author=False, + ), + ], + ) + + model_strings = result.format_model_strings() + assert "moonshotai/Kimi-K2-Instruct:groq" in model_strings + assert "moonshotai/Kimi-K2-Instruct:together" in model_strings diff --git a/uv.lock b/uv.lock index 87764306a..046bf4474 100644 --- a/uv.lock +++ b/uv.lock @@ -532,16 +532,6 @@ azure = [ bedrock = [ { name = "boto3" }, ] -dev = [ - { name = "pre-commit" }, - { name = "pydantic" }, - { name = "pytest" }, - { name = "pytest-asyncio" }, - { name = "pytest-cov" }, - { name = "pyyaml" }, - { name = "ruamel-yaml" }, - { name = "ruff" }, -] tensorzero = [ { name = "tensorzero" }, ] @@ -551,6 +541,7 @@ textual = [ [package.dev-dependencies] dev = [ + { name = "boto3" }, { name = "ipdb" }, { name = "pre-commit" }, { name = "pydantic" }, @@ -585,31 +576,24 @@ requires-dist = [ { name = "opentelemetry-instrumentation-google-genai", specifier = ">=0.4b0" }, { name = "opentelemetry-instrumentation-mcp", marker = "python_full_version >= '3.10' and python_full_version < '4'", specifier = ">=0.49.5" }, { name = "opentelemetry-instrumentation-openai", marker = "python_full_version >= '3.10' and python_full_version < '4'", specifier = ">=0.49.5" }, - { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=4.0.1" }, { name = "prompt-toolkit", specifier = ">=3.0.52" }, { name = "pydantic", specifier = ">=2.10.4" }, - { name = "pydantic", marker = "extra == 'dev'", specifier = ">=2.10.4" }, { name = "pydantic-settings", specifier = ">=2.7.0" }, { name = "pyperclip", specifier = ">=1.9.0" }, - { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0" }, - { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.21.1" }, - { name = "pytest-cov", marker = "extra == 'dev'" }, { name = "python-frontmatter", specifier = ">=1.1.0" }, { name = "pyyaml", specifier = ">=6.0.2" }, - { name = "pyyaml", marker = "extra == 'dev'", specifier = ">=6.0.2" }, { name = "rich", specifier = ">=14.2.0" }, - { name = "ruamel-yaml", marker = "extra == 'dev'", specifier = ">=0.18.0" }, - { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.8.4" }, 
{ name = "tensorzero", marker = "extra == 'all-providers'", specifier = ">=2025.7.5" }, { name = "tensorzero", marker = "extra == 'tensorzero'", specifier = ">=2025.7.5" }, { name = "textual", marker = "extra == 'textual'", specifier = ">=6.2.1" }, { name = "tiktoken", specifier = ">=0.12.0" }, { name = "typer", specifier = ">=0.20.0" }, ] -provides-extras = ["azure", "bedrock", "tensorzero", "textual", "all-providers", "dev"] +provides-extras = ["azure", "bedrock", "tensorzero", "textual", "all-providers"] [package.metadata.requires-dev] dev = [ + { name = "boto3", specifier = ">=1.35.0" }, { name = "ipdb", specifier = ">=0.13.13" }, { name = "pre-commit", specifier = ">=4.0.1" }, { name = "pydantic", specifier = ">=2.10.4" }, From dec1d28a281bcb8d1745688afd5ca57d429458f6 Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Tue, 23 Dec 2025 19:27:16 +0000 Subject: [PATCH 06/15] type safety, stream fixes, glm-4.7 --- .../src/hf_inference_acp/agents.py | 12 +- src/fast_agent/agents/llm_agent.py | 7 +- src/fast_agent/agents/llm_decorator.py | 13 +- src/fast_agent/cli/commands/check_config.py | 26 +- src/fast_agent/cli/commands/quickstart.py | 4 +- src/fast_agent/core/logging/transport.py | 12 +- src/fast_agent/core/validation.py | 2 +- src/fast_agent/llm/model_database.py | 10 + src/fast_agent/llm/model_factory.py | 2 +- .../llm/provider/anthropic/llm_anthropic.py | 5 +- .../llm/provider/google/llm_google_native.py | 5 +- .../llm/provider/openai/llm_aliyun.py | 5 +- .../llm/provider/openai/llm_azure.py | 5 +- .../llm/provider/openai/llm_deepseek.py | 5 +- .../llm/provider/openai/llm_generic.py | 7 +- .../llm/provider/openai/llm_google_oai.py | 5 +- .../llm/provider/openai/llm_groq.py | 5 +- .../llm/provider/openai/llm_huggingface.py | 5 +- .../llm/provider/openai/llm_openai.py | 32 ++- .../llm/provider/openai/llm_openrouter.py | 5 +- .../provider/openai/llm_tensorzero_openai.py | 7 +- src/fast_agent/llm/provider/openai/llm_xai.py | 7 +- .../llm/provider/openai/responses.py | 5 +- src/fast_agent/llm/provider_types.py | 5 + .../mcp/mcp_agent_client_session.py | 11 +- src/fast_agent/ui/markdown_truncator.py | 6 + src/fast_agent/ui/streaming.py | 239 ++++++++++++------ .../llm/test_openai_stream_dedup.py | 21 ++ .../ui/test_markdown_truncator_streaming.py | 20 ++ .../ui/test_streaming_mode_switch.py | 68 +++++ .../ui/test_streaming_table_chunking.py | 55 ++++ typesafe.md | 33 +++ 32 files changed, 505 insertions(+), 144 deletions(-) create mode 100644 tests/unit/fast_agent/llm/test_openai_stream_dedup.py create mode 100644 tests/unit/fast_agent/ui/test_streaming_mode_switch.py create mode 100644 tests/unit/fast_agent/ui/test_streaming_table_chunking.py diff --git a/publish/hf-inference-acp/src/hf_inference_acp/agents.py b/publish/hf-inference-acp/src/hf_inference_acp/agents.py index 0ba57d750..f076cd5d7 100644 --- a/publish/hf-inference-acp/src/hf_inference_acp/agents.py +++ b/publish/hf-inference-acp/src/hf_inference_acp/agents.py @@ -504,26 +504,24 @@ async def _send_connect_update( # Rebuild system prompt to include fresh server instructions await _send_connect_update(title="Rebuilding system prompt…", status="in_progress") - await self.rebuild_instruction_templates() + await self._apply_instruction_templates() # Get available tools await _send_connect_update(title="Fetching available tools…", status="in_progress") tools_result = await self._aggregator.list_tools() tool_names = [t.name for t in tools_result.tools] if tools_result.tools else [] + # Send 
the final progress update; its status marks this step complete, and
+        # the method's return value signals overall completion to the caller
         if tool_names:
-            preview = ", ".join(tool_names[:10])
-            suffix = f" (+{len(tool_names) - 10} more)" if len(tool_names) > 10 else ""
             await _send_connect_update(
-                title="Connected (tools available)",
+                title=f"Connected ({len(tool_names)} tools)",
                 status="completed",
-                message=f"Available tools: {preview}{suffix}",
             )
         else:
             await _send_connect_update(
-                title="Connected (no tools found)",
+                title="Connected (no tools)",
                 status="completed",
-                message="No tools available from the server.",
             )

         if tool_names:
diff --git a/src/fast_agent/agents/llm_agent.py b/src/fast_agent/agents/llm_agent.py
index 55cc9cc1f..97f64555c 100644
--- a/src/fast_agent/agents/llm_agent.py
+++ b/src/fast_agent/agents/llm_agent.py
@@ -286,8 +286,9 @@ async def generate_impl(
         summary_text: Text | None = None

         if self._should_stream():
+            llm = self._require_llm()
             display_name = self.name
-            display_model = self.llm.model_name if self.llm else None
+            display_model = llm.model_name
             remove_listener: Callable[[], None] | None = None
             remove_tool_listener: Callable[[], None] | None = None

@@ -297,8 +298,8 @@
                 model=display_model,
             ) as stream_handle:
                 try:
-                    remove_listener = self.llm.add_stream_listener(stream_handle.update_chunk)
-                    remove_tool_listener = self.llm.add_tool_stream_listener(
+                    remove_listener = llm.add_stream_listener(stream_handle.update_chunk)
+                    remove_tool_listener = llm.add_tool_stream_listener(
                         stream_handle.handle_tool_event
                     )
                 except Exception:
diff --git a/src/fast_agent/agents/llm_decorator.py b/src/fast_agent/agents/llm_decorator.py
index e0f9f2698..360ee6da6 100644
--- a/src/fast_agent/agents/llm_decorator.py
+++ b/src/fast_agent/agents/llm_decorator.py
@@ -11,6 +11,7 @@
     Any,
     Callable,
     Mapping,
+    Self,
     Sequence,
     Type,
     TypeVar,
@@ -285,7 +286,7 @@ def _clone_constructor_kwargs(self) -> dict[str, Any]:
        """Hook for subclasses/mixins to supply constructor kwargs when cloning."""
         return {}

-    async def spawn_detached_instance(self, *, name: str | None = None) -> "LlmAgent":
+    async def spawn_detached_instance(self, *, name: str | None = None) -> Self:
        """Create a fresh agent instance with its own MCP/LLM stack."""

         new_config = deepcopy(self.config)
@@ -610,7 +611,7 @@ def _prepare_llm_call(
     ) -> _CallContext:
        """Normalize template/history handling for both generate and structured."""
         sanitized_messages, summary = self._sanitize_messages_for_llm(messages)
-        final_request_params = self._llm.get_request_params(request_params)
+        final_request_params = self._require_llm().get_request_params(request_params)
         use_history = final_request_params.use_history if final_request_params else True

         call_params = final_request_params.model_copy() if final_request_params else None
@@ -1032,6 +1033,12 @@ def usage_accumulator(self) -> UsageAccumulator | None:
     def llm(self) -> FastAgentLLMProtocol | None:
         return self._llm

+    def _require_llm(self) -> FastAgentLLMProtocol:
+        """Return the attached LLM, raising if not yet attached."""
+        if self._llm is None:
+            raise RuntimeError(f"Agent '{self._name}' has no LLM attached")
+        return self._llm
+
     # --- Default MCP-facing convenience methods (no-op for plain LLM agents) ---

     async def list_prompts(self, namespace: str | None = None) -> Mapping[str, list[Prompt]]:
@@ -1076,7 +1083,7 @@ async def with_resource(

     @property
     def provider(self) -> Provider:
-        return self.llm.provider
+        return self._require_llm().provider

     def _merge_request_params(
         self,
diff --git
a/src/fast_agent/cli/commands/check_config.py b/src/fast_agent/cli/commands/check_config.py index 3de74faf9..0dbe9975b 100644 --- a/src/fast_agent/cli/commands/check_config.py +++ b/src/fast_agent/cli/commands/check_config.py @@ -82,7 +82,7 @@ def check_api_keys(secrets_summary: dict, config_summary: dict) -> dict: import os results = { - provider.value: {"env": "", "config": ""} + provider.config_name: {"env": "", "config": ""} for provider in Provider if provider != Provider.FAST_AGENT } @@ -97,18 +97,18 @@ def check_api_keys(secrets_summary: dict, config_summary: dict) -> dict: if config and "azure" in config.get("config", {}): config_azure = config["config"]["azure"] - for provider in results: + for provider_name in results: # Always check environment variables first - env_key_name = ProviderKeyManager.get_env_key_name(provider) + env_key_name = ProviderKeyManager.get_env_key_name(provider_name) env_key_value = os.environ.get(env_key_name) if env_key_value: if len(env_key_value) > 5: - results[provider]["env"] = f"...{env_key_value[-5:]}" + results[provider_name]["env"] = f"...{env_key_value[-5:]}" else: - results[provider]["env"] = "...***" + results[provider_name]["env"] = "...***" # Special handling for Azure: support api_key and DefaultAzureCredential - if provider == "azure": + if provider_name == "azure": # Prefer secrets if present, else fallback to config azure_cfg = {} if secrets_status == "parsed" and "azure" in secrets: @@ -119,17 +119,17 @@ def check_api_keys(secrets_summary: dict, config_summary: dict) -> dict: use_default_cred = azure_cfg.get("use_default_azure_credential", False) base_url = azure_cfg.get("base_url") if use_default_cred and base_url: - results[provider]["config"] = "DefaultAzureCredential" + results[provider_name]["config"] = "DefaultAzureCredential" continue # Check secrets file if it was parsed successfully if secrets_status == "parsed": - config_key = ProviderKeyManager.get_config_file_key(provider, secrets) + config_key = ProviderKeyManager.get_config_file_key(provider_name, secrets) if config_key and config_key != API_KEY_HINT_TEXT: if len(config_key) > 5: - results[provider]["config"] = f"...{config_key[-5:]}" + results[provider_name]["config"] = f"...{config_key[-5:]}" else: - results[provider]["config"] = "...***" + results[provider_name]["config"] = "...***" return results @@ -277,7 +277,9 @@ def get_config_summary(config_path: Path | None) -> dict: if server_info["url"] and len(server_info["url"]) > 60: server_info["url"] = server_info["url"][:57] + "..." - result["mcp_servers"].append(server_info) + mcp_servers = result["mcp_servers"] + assert isinstance(mcp_servers, list) + mcp_servers.append(server_info) # Skills directory override skills_cfg = config.get("skills") if isinstance(config, dict) else None @@ -711,7 +713,7 @@ def _truncate(text: str, length: int = 70) -> str: console.print("1. 
Add keys to fastagent.secrets.yaml") env_vars = ", ".join( [ - ProviderKeyManager.get_env_key_name(p.value) + ProviderKeyManager.get_env_key_name(p.config_name) for p in Provider if p != Provider.FAST_AGENT ] diff --git a/src/fast_agent/cli/commands/quickstart.py b/src/fast_agent/cli/commands/quickstart.py index f0f545df0..f10aa737b 100644 --- a/src/fast_agent/cli/commands/quickstart.py +++ b/src/fast_agent/cli/commands/quickstart.py @@ -189,6 +189,7 @@ def copy_example_files(example_type: str, target_dir: Path, force: bool = False) if use_as_file: source_path = stack.enter_context(as_file(source_dir)) # type: ignore else: + assert isinstance(source_dir, Path) source_path = source_dir if not source_path.exists(): @@ -502,7 +503,8 @@ def tensorzero( if use_as_file: source_path = stack.enter_context(as_file(source_dir)) # type: ignore else: - source_path = source_dir # type: ignore[assignment] + assert isinstance(source_dir, Path) + source_path = source_dir if not source_path.exists() or not source_path.is_dir(): console.print(f"[red]Error: Source project directory not found at '{source_path}'[/red]") diff --git a/src/fast_agent/core/logging/transport.py b/src/fast_agent/core/logging/transport.py index 71f210855..66fb646f1 100644 --- a/src/fast_agent/core/logging/transport.py +++ b/src/fast_agent/core/logging/transport.py @@ -182,7 +182,7 @@ class HTTPTransport(FilteredEventTransport): def __init__( self, endpoint: str, - headers: dict[str, str] = None, + headers: dict[str, str] | None = None, batch_size: int = 100, timeout: float = 5.0, event_filter: EventFilter | None = None, @@ -227,6 +227,7 @@ async def _flush(self) -> None: if not self._session: await self.start() + assert self._session is not None try: # Convert events to JSON-serializable dicts @@ -240,7 +241,7 @@ async def _flush(self) -> None: "data": self._serializer(event.data), "trace_id": event.trace_id, "span_id": event.span_id, - "context": event.context.dict() if event.context else None, + "context": event.context.model_dump() if event.context else None, } for event in self.batch ] @@ -328,7 +329,7 @@ async def stop(self) -> None: self._running = False # Try to process remaining items with a timeout - if not self._queue.empty(): + if self._queue is not None and not self._queue.empty(): try: # Give some time for remaining items to be processed await asyncio.wait_for(self._queue.join(), timeout=5.0) @@ -385,7 +386,8 @@ async def emit(self, event: Event) -> None: print(f"Error in transport.send_event: {e}") # Then queue for listeners - await self._queue.put(event) + if self._queue is not None: + await self._queue.put(event) def add_listener(self, name: str, listener: EventListener) -> None: """Add a listener to the event bus.""" @@ -401,6 +403,8 @@ async def _process_events(self) -> None: event = None try: # Use wait_for with a timeout to allow checking running state + if self._queue is None: + continue try: event = await asyncio.wait_for(self._queue.get(), timeout=0.1) except asyncio.TimeoutError: diff --git a/src/fast_agent/core/validation.py b/src/fast_agent/core/validation.py index d011302d4..874b16ec8 100644 --- a/src/fast_agent/core/validation.py +++ b/src/fast_agent/core/validation.py @@ -144,7 +144,7 @@ def get_dependencies( agents: dict[str, dict[str, Any]], visited: set, path: set, - agent_type: AgentType = None, + agent_type: AgentType | None = None, ) -> list[str]: """ Get dependencies for an agent in topological order. 
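
The `headers` and `agent_type` fixes above are the recurring implicit-`Optional` repair in this series: a parameter annotated `T` but defaulted to `None` has to be declared `T | None` before strict checkers accept it. A minimal runnable sketch of the pattern, using a hypothetical `RetryPolicy` type rather than any fast-agent class:

```python
from dataclasses import dataclass


@dataclass
class RetryPolicy:
    """Hypothetical stand-in for types like SignalHandler or AgentType."""

    attempts: int = 3


# Before: `policy: RetryPolicy = None` -- strict checkers reject the
# implicit Optional, since None is not a valid RetryPolicy.
def connect(policy: RetryPolicy | None = None) -> str:
    # Declare the None default honestly, then narrow it before use.
    effective = policy if policy is not None else RetryPolicy()
    return f"connecting with {effective.attempts} attempts"


print(connect())  # falls back to the default policy
print(connect(RetryPolicy(attempts=5)))  # explicit override
```
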
diff --git a/src/fast_agent/llm/model_database.py b/src/fast_agent/llm/model_database.py index e63cbdba9..31b23a9a9 100644 --- a/src/fast_agent/llm/model_database.py +++ b/src/fast_agent/llm/model_database.py @@ -221,6 +221,15 @@ class ModelDatabase: stream_mode="manual", ) + GLM_47 = ModelParameters( + context_window=202752, + max_output_tokens=65536, # default from https://docs.z.ai/guides/overview/concept-param#token-usage-calculation - max is 131072 + tokenizes=TEXT_ONLY, + json_mode="object", + reasoning="reasoning_content", + stream_mode="manual", + ) + HF_PROVIDER_DEEPSEEK31 = ModelParameters( context_window=163_800, max_output_tokens=8192, tokenizes=TEXT_ONLY ) @@ -338,6 +347,7 @@ class ModelDatabase: "openai/gpt-oss-120b": OPENAI_GPT_OSS_SERIES, # https://cookbook.openai.com/articles/openai-harmony "openai/gpt-oss-20b": OPENAI_GPT_OSS_SERIES, # tool/reasoning interleave guidance "zai-org/glm-4.6": GLM_46, + "zai-org/glm-4.7": GLM_47, "minimaxai/minimax-m2": GLM_46, "qwen/qwen3-next-80b-a3b-instruct": HF_PROVIDER_QWEN3_NEXT, "deepseek-ai/deepseek-v3.1": HF_PROVIDER_DEEPSEEK31, diff --git a/src/fast_agent/llm/model_factory.py b/src/fast_agent/llm/model_factory.py index 14c33aeb6..791c709a0 100644 --- a/src/fast_agent/llm/model_factory.py +++ b/src/fast_agent/llm/model_factory.py @@ -144,7 +144,7 @@ class ModelFactory: "kimi": "hf.moonshotai/Kimi-K2-Instruct-0905:groq", "gpt-oss": "hf.openai/gpt-oss-120b", "gpt-oss-20b": "hf.openai/gpt-oss-20b", - "glm": "hf.zai-org/GLM-4.6:cerebras", + "glm": "hf.zai-org/GLM-4.7:zai-org", "qwen3": "hf.Qwen/Qwen3-Next-80B-A3B-Instruct:together", "deepseek31": "hf.deepseek-ai/DeepSeek-V3.1", "kimithink": "hf.moonshotai/Kimi-K2-Thinking:together", diff --git a/src/fast_agent/llm/provider/anthropic/llm_anthropic.py b/src/fast_agent/llm/provider/anthropic/llm_anthropic.py index 09bd58be5..36bb441d5 100644 --- a/src/fast_agent/llm/provider/anthropic/llm_anthropic.py +++ b/src/fast_agent/llm/provider/anthropic/llm_anthropic.py @@ -70,9 +70,10 @@ class AnthropicLLM(FastAgentLLM[MessageParam, Message]): FastAgentLLM.PARAM_MCP_METADATA, } - def __init__(self, *args, **kwargs) -> None: + def __init__(self, **kwargs) -> None: # Initialize logger - keep it simple without name reference - super().__init__(*args, provider=Provider.ANTHROPIC, **kwargs) + kwargs.pop("provider", None) + super().__init__(provider=Provider.ANTHROPIC, **kwargs) def _initialize_default_params(self, kwargs: dict) -> RequestParams: """Initialize Anthropic-specific default parameters""" diff --git a/src/fast_agent/llm/provider/google/llm_google_native.py b/src/fast_agent/llm/provider/google/llm_google_native.py index 40a92b678..2a83bddb0 100644 --- a/src/fast_agent/llm/provider/google/llm_google_native.py +++ b/src/fast_agent/llm/provider/google/llm_google_native.py @@ -49,8 +49,9 @@ class GoogleNativeLLM(FastAgentLLM[types.Content, types.Content]): Google LLM provider using the native google.genai library. 
""" - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, provider=Provider.GOOGLE, **kwargs) + def __init__(self, **kwargs) -> None: + kwargs.pop("provider", None) + super().__init__(provider=Provider.GOOGLE, **kwargs) # Initialize the converter self._converter = GoogleConverter() diff --git a/src/fast_agent/llm/provider/openai/llm_aliyun.py b/src/fast_agent/llm/provider/openai/llm_aliyun.py index e46310043..d0516f7ca 100644 --- a/src/fast_agent/llm/provider/openai/llm_aliyun.py +++ b/src/fast_agent/llm/provider/openai/llm_aliyun.py @@ -8,8 +8,9 @@ class AliyunLLM(GroqLLM): - def __init__(self, *args, **kwargs) -> None: - OpenAILLM.__init__(self, *args, provider=Provider.ALIYUN, **kwargs) + def __init__(self, **kwargs) -> None: + kwargs.pop("provider", None) + OpenAILLM.__init__(self, provider=Provider.ALIYUN, **kwargs) def _initialize_default_params(self, kwargs: dict) -> RequestParams: """Initialize Aliyun-specific default parameters""" diff --git a/src/fast_agent/llm/provider/openai/llm_azure.py b/src/fast_agent/llm/provider/openai/llm_azure.py index 5b06f844a..0ccf05c94 100644 --- a/src/fast_agent/llm/provider/openai/llm_azure.py +++ b/src/fast_agent/llm/provider/openai/llm_azure.py @@ -27,9 +27,10 @@ class AzureOpenAILLM(OpenAILLM): Handles both API Key and DefaultAzureCredential authentication. """ - def __init__(self, provider: Provider = Provider.AZURE, *args, **kwargs): + def __init__(self, provider: Provider = Provider.AZURE, **kwargs): # Set provider to AZURE, pass through to base - super().__init__(provider=provider, *args, **kwargs) + kwargs.pop("provider", None) + super().__init__(provider=provider, **kwargs) # Context/config extraction context = getattr(self, "context", None) diff --git a/src/fast_agent/llm/provider/openai/llm_deepseek.py b/src/fast_agent/llm/provider/openai/llm_deepseek.py index 21309c838..770c3a3f1 100644 --- a/src/fast_agent/llm/provider/openai/llm_deepseek.py +++ b/src/fast_agent/llm/provider/openai/llm_deepseek.py @@ -16,8 +16,9 @@ class DeepSeekLLM(OpenAICompatibleLLM): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, provider=Provider.DEEPSEEK, **kwargs) + def __init__(self, **kwargs) -> None: + kwargs.pop("provider", None) + super().__init__(provider=Provider.DEEPSEEK, **kwargs) def _initialize_default_params(self, kwargs: dict) -> RequestParams: """Initialize Deepseek-specific default parameters""" diff --git a/src/fast_agent/llm/provider/openai/llm_generic.py b/src/fast_agent/llm/provider/openai/llm_generic.py index 57cfdf7aa..95703adba 100644 --- a/src/fast_agent/llm/provider/openai/llm_generic.py +++ b/src/fast_agent/llm/provider/openai/llm_generic.py @@ -10,10 +10,9 @@ class GenericLLM(OpenAILLM): - def __init__(self, *args, **kwargs) -> None: - super().__init__( - *args, provider=Provider.GENERIC, **kwargs - ) # Properly pass args and kwargs to parent + def __init__(self, **kwargs) -> None: + kwargs.pop("provider", None) + super().__init__(provider=Provider.GENERIC, **kwargs) def _initialize_default_params(self, kwargs: dict) -> RequestParams: """Initialize Generic parameters""" diff --git a/src/fast_agent/llm/provider/openai/llm_google_oai.py b/src/fast_agent/llm/provider/openai/llm_google_oai.py index 038493409..d424e088e 100644 --- a/src/fast_agent/llm/provider/openai/llm_google_oai.py +++ b/src/fast_agent/llm/provider/openai/llm_google_oai.py @@ -9,8 +9,9 @@ class GoogleOaiLLM(OpenAILLM): config_section = "google" - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, 
provider=Provider.GOOGLE_OAI, **kwargs) + def __init__(self, **kwargs) -> None: + kwargs.pop("provider", None) + super().__init__(provider=Provider.GOOGLE_OAI, **kwargs) def _initialize_default_params(self, kwargs: dict) -> RequestParams: """Initialize Google OpenAI Compatibility default parameters""" diff --git a/src/fast_agent/llm/provider/openai/llm_groq.py b/src/fast_agent/llm/provider/openai/llm_groq.py index 886508af3..9f43b06b6 100644 --- a/src/fast_agent/llm/provider/openai/llm_groq.py +++ b/src/fast_agent/llm/provider/openai/llm_groq.py @@ -12,8 +12,9 @@ class GroqLLM(OpenAICompatibleLLM): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, provider=Provider.GROQ, **kwargs) + def __init__(self, **kwargs) -> None: + kwargs.pop("provider", None) + super().__init__(provider=Provider.GROQ, **kwargs) def _initialize_default_params(self, kwargs: dict) -> RequestParams: """Initialize Groq default parameters""" diff --git a/src/fast_agent/llm/provider/openai/llm_huggingface.py b/src/fast_agent/llm/provider/openai/llm_huggingface.py index 7a7629b9b..d38f65dca 100644 --- a/src/fast_agent/llm/provider/openai/llm_huggingface.py +++ b/src/fast_agent/llm/provider/openai/llm_huggingface.py @@ -9,9 +9,10 @@ class HuggingFaceLLM(OpenAICompatibleLLM): - def __init__(self, *args, **kwargs) -> None: + def __init__(self, **kwargs) -> None: self._hf_provider_suffix: str | None = None - super().__init__(*args, provider=Provider.HUGGINGFACE, **kwargs) + kwargs.pop("provider", None) + super().__init__(provider=Provider.HUGGINGFACE, **kwargs) def _initialize_default_params(self, kwargs: dict) -> RequestParams: """Initialize HuggingFace-specific default parameters""" diff --git a/src/fast_agent/llm/provider/openai/llm_openai.py b/src/fast_agent/llm/provider/openai/llm_openai.py index da2523e57..e84c78196 100644 --- a/src/fast_agent/llm/provider/openai/llm_openai.py +++ b/src/fast_agent/llm/provider/openai/llm_openai.py @@ -104,8 +104,9 @@ class OpenAILLM(FastAgentLLM[ChatCompletionMessageParam, ChatCompletionMessage]) FastAgentLLM.PARAM_STOP_SEQUENCES, } - def __init__(self, provider: Provider = Provider.OPENAI, *args, **kwargs) -> None: - super().__init__(*args, provider=provider, **kwargs) + def __init__(self, provider: Provider = Provider.OPENAI, **kwargs) -> None: + kwargs.pop("provider", None) + super().__init__(provider=provider, **kwargs) # Initialize logger with name if available self.logger = get_logger(f"{__name__}.{self.name}" if self.name else __name__) @@ -447,6 +448,15 @@ def _close_reasoning_if_active(self, reasoning_active: bool) -> bool: """Return reasoning state; kept for symmetry.""" return False if reasoning_active else reasoning_active + @staticmethod + def _extract_incremental_delta(delta: str, cumulative: str) -> tuple[str, str]: + """Return the incremental portion of a possibly cumulative stream delta.""" + if not delta: + return "", cumulative + if cumulative and delta.startswith(cumulative): + return delta[len(cumulative) :], delta + return delta, cumulative + delta + async def _process_stream( self, stream, @@ -472,6 +482,7 @@ async def _process_stream( # Use ChatCompletionStreamState helper for accumulation (OpenAI only) state = ChatCompletionStreamState() + cumulative_content = "" # Track tool call state for stream events tool_call_started: dict[int, dict[str, Any]] = {} @@ -512,8 +523,13 @@ async def _process_stream( # Handle text content streaming if delta.content: + incremental, cumulative_content = self._extract_incremental_delta( + delta.content, 
cumulative_content + ) + if not incremental: + continue estimated_tokens, reasoning_active = self._emit_text_delta( - content=delta.content, + content=incremental, model=model, estimated_tokens=estimated_tokens, streams_arguments=streams_arguments, @@ -621,6 +637,7 @@ async def _process_stream_manual( # Manual accumulation of response data accumulated_content = "" + cumulative_content = "" role = "assistant" tool_calls_map = {} # Use a map to accumulate tool calls by index function_call = None @@ -665,14 +682,19 @@ async def _process_stream_manual( # Handle text content streaming if delta.content: + incremental, cumulative_content = self._extract_incremental_delta( + delta.content, cumulative_content + ) + if not incremental: + continue estimated_tokens, reasoning_active = self._emit_text_delta( - content=delta.content, + content=incremental, model=model, estimated_tokens=estimated_tokens, streams_arguments=streams_arguments, reasoning_active=reasoning_active, ) - accumulated_content += delta.content + accumulated_content += incremental # Fire "stop" event when tool calls complete if choice.finish_reason == "tool_calls": diff --git a/src/fast_agent/llm/provider/openai/llm_openrouter.py b/src/fast_agent/llm/provider/openai/llm_openrouter.py index de2c9b968..81029e3c4 100644 --- a/src/fast_agent/llm/provider/openai/llm_openrouter.py +++ b/src/fast_agent/llm/provider/openai/llm_openrouter.py @@ -12,8 +12,9 @@ class OpenRouterLLM(OpenAILLM): """Augmented LLM provider for OpenRouter, using an OpenAI-compatible API.""" - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, provider=Provider.OPENROUTER, **kwargs) + def __init__(self, **kwargs) -> None: + kwargs.pop("provider", None) + super().__init__(provider=Provider.OPENROUTER, **kwargs) def _initialize_default_params(self, kwargs: dict) -> RequestParams: """Initialize OpenRouter-specific default parameters.""" diff --git a/src/fast_agent/llm/provider/openai/llm_tensorzero_openai.py b/src/fast_agent/llm/provider/openai/llm_tensorzero_openai.py index 2bc37122c..e9074f208 100644 --- a/src/fast_agent/llm/provider/openai/llm_tensorzero_openai.py +++ b/src/fast_agent/llm/provider/openai/llm_tensorzero_openai.py @@ -14,18 +14,17 @@ class TensorZeroOpenAILLM(OpenAILLM): features, such as system template variables and custom parameters. """ - def __init__(self, *args, **kwargs) -> None: + def __init__(self, **kwargs) -> None: """ Initializes the TensorZeroOpenAIAugmentedLLM. Args: - *args: Variable length argument list. **kwargs: Arbitrary keyword arguments. 
""" self._t0_episode_id = kwargs.pop("episode_id", None) self._t0_function_name = kwargs.get("model", "") - - super().__init__(*args, provider=Provider.TENSORZERO, **kwargs) + kwargs.pop("provider", None) + super().__init__(provider=Provider.TENSORZERO, **kwargs) self.logger.info("TensorZeroOpenAILLM initialized.") def _initialize_default_params(self, kwargs: dict) -> RequestParams: diff --git a/src/fast_agent/llm/provider/openai/llm_xai.py b/src/fast_agent/llm/provider/openai/llm_xai.py index 72ef67fb6..126a6915d 100644 --- a/src/fast_agent/llm/provider/openai/llm_xai.py +++ b/src/fast_agent/llm/provider/openai/llm_xai.py @@ -9,10 +9,9 @@ class XAILLM(OpenAILLM): - def __init__(self, *args, **kwargs) -> None: - super().__init__( - *args, provider=Provider.XAI, **kwargs - ) # Properly pass args and kwargs to parent + def __init__(self, **kwargs) -> None: + kwargs.pop("provider", None) + super().__init__(provider=Provider.XAI, **kwargs) def _initialize_default_params(self, kwargs: dict) -> RequestParams: """Initialize xAI parameters""" diff --git a/src/fast_agent/llm/provider/openai/responses.py b/src/fast_agent/llm/provider/openai/responses.py index 99475a068..23160bcbc 100644 --- a/src/fast_agent/llm/provider/openai/responses.py +++ b/src/fast_agent/llm/provider/openai/responses.py @@ -41,8 +41,9 @@ class ResponsesLLM(FastAgentLLM[ChatCompletionMessageParam, ChatCompletionMessag # OpenAI-specific parameter exclusions - def __init__(self, provider=Provider.RESPONSES, *args, **kwargs): - super().__init__(*args, provider=provider, **kwargs) + def __init__(self, provider=Provider.RESPONSES, **kwargs): + kwargs.pop("provider", None) + super().__init__(provider=provider, **kwargs) async def _responses_client(self) -> AsyncOpenAI: return AsyncOpenAI(api_key=self._api_key()) diff --git a/src/fast_agent/llm/provider_types.py b/src/fast_agent/llm/provider_types.py index db1f76adf..e12d127e2 100644 --- a/src/fast_agent/llm/provider_types.py +++ b/src/fast_agent/llm/provider_types.py @@ -16,6 +16,11 @@ def __new__(cls, config_name, display_name=None): obj.display_name = display_name or config_name.title() return obj + @property + def config_name(self) -> str: + """Return the provider's config name (typed accessor for _value_).""" + return self._value_ # type: ignore[return-value] + ANTHROPIC = ("anthropic", "Anthropic") DEEPSEEK = ("deepseek", "Deepseek") FAST_AGENT = ("fast-agent", "fast-agent-internal") diff --git a/src/fast_agent/mcp/mcp_agent_client_session.py b/src/fast_agent/mcp/mcp_agent_client_session.py index d9451f0a8..1085ebfa6 100644 --- a/src/fast_agent/mcp/mcp_agent_client_session.py +++ b/src/fast_agent/mcp/mcp_agent_client_session.py @@ -72,7 +72,7 @@ class MCPAgentClientSession(ClientSession, ContextDependent): Developers can extend this class to add more custom functionality as needed """ - def __init__(self, *args, **kwargs) -> None: + def __init__(self, read_stream, write_stream, read_timeout=None, **kwargs) -> None: # Extract server_name if provided in kwargs from importlib.metadata import version @@ -172,8 +172,15 @@ def __init__(self, *args, **kwargs) -> None: else: self.effective_elicitation_mode = "none" + # Pop parameters we're explicitly setting to avoid duplicates + kwargs.pop("list_roots_callback", None) + kwargs.pop("sampling_callback", None) + kwargs.pop("client_info", None) + kwargs.pop("elicitation_callback", None) super().__init__( - *args, + read_stream, + write_stream, + read_timeout, **kwargs, list_roots_callback=list_roots_cb, sampling_callback=sampling_cb, diff 
--git a/src/fast_agent/ui/markdown_truncator.py b/src/fast_agent/ui/markdown_truncator.py index 1c7e61212..a38394eaf 100644 --- a/src/fast_agent/ui/markdown_truncator.py +++ b/src/fast_agent/ui/markdown_truncator.py @@ -725,6 +725,12 @@ def _ensure_table_header_if_needed(self, original_text: str, truncated_text: str if truncation_pos >= table.thead_end_pos: # Header completely scrolled off - prepend it header_text = "\n".join(table.header_lines) + "\n" + truncated_lines = truncated_text.splitlines() + header_lines = [line.rstrip() for line in table.header_lines] + if len(truncated_lines) >= len(header_lines): + candidate = [line.rstrip() for line in truncated_lines[: len(header_lines)]] + if candidate == header_lines: + return truncated_text return header_text + truncated_text else: # Header still on screen diff --git a/src/fast_agent/ui/streaming.py b/src/fast_agent/ui/streaming.py index b41141ec7..430639398 100644 --- a/src/fast_agent/ui/streaming.py +++ b/src/fast_agent/ui/streaming.py @@ -72,6 +72,8 @@ def __init__( self._highlight_index = highlight_index self._max_item_length = max_item_length self._use_plain_text = use_plain_text + self._preferred_plain_text = use_plain_text + self._plain_text_override_count = 0 self._header_left = header_left self._header_right = header_right self._progress_display = progress_display @@ -122,6 +124,8 @@ def __init__( self._reasoning_parser = ReasoningStreamParser() self._styled_buffer: list[tuple[str, bool]] = [] self._has_reasoning = False + self._reasoning_active = False + self._tool_active = False if self._async_mode and self._loop and self._queue is not None: self._worker_task = self._loop.create_task(self._render_worker()) @@ -243,6 +247,129 @@ def _switch_to_plain_text(self, style: str | None = "dim") -> None: self._plain_text_style = style self._convert_literal_newlines = True + def _switch_to_markdown(self) -> None: + self._use_plain_text = False + self._plain_text_style = None + self._convert_literal_newlines = False + self._pending_literal_backslashes = "" + + def _insert_mode_switch_newline(self) -> None: + if self._pending_table_row: + return + if not self._buffer: + return + if self._buffer[-1].endswith("\n"): + return + self._buffer.append("\n") + if self._has_reasoning: + self._styled_buffer.append(("\n", False)) + + def _set_use_plain_text(self, use_plain_text: bool, *, insert_newline: bool) -> None: + if use_plain_text == self._use_plain_text: + return + if insert_newline: + self._insert_mode_switch_newline() + if use_plain_text: + self._switch_to_plain_text(style=None) + else: + self._switch_to_markdown() + + def _begin_plain_text_override(self) -> None: + self._plain_text_override_count += 1 + if self._plain_text_override_count == 1: + self._set_use_plain_text(True, insert_newline=True) + + def _end_plain_text_override(self) -> None: + if self._plain_text_override_count == 0: + return + self._plain_text_override_count -= 1 + if self._plain_text_override_count == 0: + self._set_use_plain_text(self._preferred_plain_text, insert_newline=True) + + def _begin_reasoning_mode(self) -> None: + if self._reasoning_active: + return + self._reasoning_active = True + if self._buffer and not self._styled_buffer: + self._styled_buffer.append(("".join(self._buffer), False)) + self._has_reasoning = True + self._begin_plain_text_override() + + def _end_reasoning_mode(self) -> None: + if not self._reasoning_active: + return + self._reasoning_active = False + self._end_plain_text_override() + + def _begin_tool_mode(self) -> None: + if 
self._tool_active: + return + self._tool_active = True + self._begin_plain_text_override() + + def _end_tool_mode(self) -> None: + if not self._tool_active: + return + self._tool_active = False + self._end_plain_text_override() + + def _append_plain_text(self, text: str, *, is_reasoning: bool | None = None) -> bool: + processed = text + if self._convert_literal_newlines: + processed = self._decode_literal_newlines(processed) + if not processed: + return False + processed = self._wrap_plain_chunk(processed) + if self._pending_table_row: + self._buffer.append(self._pending_table_row) + self._pending_table_row = "" + self._buffer.append(processed) + if self._has_reasoning: + self._styled_buffer.append((processed, bool(is_reasoning))) + return True + + def _append_text_in_current_mode(self, text: str) -> bool: + if not text: + return False + if self._use_plain_text: + return self._append_plain_text(text) + + text_so_far = "".join(self._buffer) + ends_with_newline = text_so_far.endswith("\n") + lines = text_so_far.split("\n") if text_so_far else [] + last_line = "" if ends_with_newline else (lines[-1] if lines else "") + currently_in_table = last_line.strip().startswith("|") + if self._pending_table_row: + if "\n" not in text: + self._pending_table_row += text + return False + text = self._pending_table_row + text + self._pending_table_row = "" + + starts_table_row = text.lstrip().startswith("|") + if "\n" not in text and (currently_in_table or starts_table_row): + pending_seed = "" + if currently_in_table: + split_index = text_so_far.rfind("\n") + if split_index == -1: + pending_seed = text_so_far + self._buffer = [] + else: + pending_seed = text_so_far[split_index + 1 :] + prefix = text_so_far[: split_index + 1] + self._buffer = [prefix] if prefix else [] + self._pending_table_row = pending_seed + text + return False + + if self._pending_table_row: + self._buffer.append(self._pending_table_row) + self._pending_table_row = "" + + self._buffer.append(text) + if self._has_reasoning: + self._styled_buffer.append((text, False)) + return True + def finalize(self, _message: "PromptMessageExtended | str") -> None: if not self._active or self._finalized: return @@ -443,8 +570,7 @@ def _process_reasoning_chunk(self, chunk: str) -> bool: ) if not should_process and not self._has_reasoning: return False - - self._switch_to_plain_text(style=None) + previous_in_think = self._reasoning_parser.in_think segments: list[ReasoningSegment] = [] if chunk: segments = self._reasoning_parser.feed(chunk) @@ -453,45 +579,46 @@ def _process_reasoning_chunk(self, chunk: str) -> bool: if not segments: return False - - self._has_reasoning = True + handled = False + emitted_non_thinking = False for segment in segments: - processed = segment.text - if self._convert_literal_newlines: - processed = self._decode_literal_newlines(processed) - if not processed: - continue - processed = self._wrap_plain_chunk(processed) - if self._pending_table_row: - self._buffer.append(self._pending_table_row) - self._pending_table_row = "" - self._buffer.append(processed) - self._styled_buffer.append((processed, segment.is_thinking)) - - return True + if segment.is_thinking: + self._begin_reasoning_mode() + self._append_plain_text(segment.text, is_reasoning=True) + handled = True + else: + if self._reasoning_active: + self._end_reasoning_mode() + emitted_non_thinking = True + self._append_text_in_current_mode(segment.text) + handled = True + + if ( + previous_in_think + and not self._reasoning_parser.in_think + and self._reasoning_active + and 
not emitted_non_thinking + ): + self._end_reasoning_mode() + + return handled def _handle_stream_chunk(self, chunk: StreamChunk) -> bool: """Process a typed stream chunk with explicit reasoning flag.""" if not chunk.text: return False + if not chunk.is_reasoning and self._process_reasoning_chunk(chunk.text): + return True - self._switch_to_plain_text(style=None) - - processed = chunk.text - if self._convert_literal_newlines: - processed = self._decode_literal_newlines(processed) - if not processed: - return False - processed = self._wrap_plain_chunk(processed) - if self._pending_table_row: - self._buffer.append(self._pending_table_row) - self._pending_table_row = "" - self._buffer.append(processed) - self._styled_buffer.append((processed, chunk.is_reasoning)) if chunk.is_reasoning: - self._has_reasoning = True - return True + self._begin_reasoning_mode() + return self._append_plain_text(chunk.text, is_reasoning=True) + + if self._reasoning_active: + self._end_reasoning_mode() + + return self._append_text_in_current_mode(chunk.text) def _handle_chunk(self, chunk: str) -> bool: if not chunk: @@ -499,35 +626,7 @@ def _handle_chunk(self, chunk: str) -> bool: if self._process_reasoning_chunk(chunk): return True - - if self._use_plain_text: - if self._convert_literal_newlines: - chunk = self._decode_literal_newlines(chunk) - if not chunk: - if self._pending_table_row: - self._buffer.append(self._pending_table_row) - self._pending_table_row = "" - return False - chunk = self._wrap_plain_chunk(chunk) - if self._pending_table_row: - self._buffer.append(self._pending_table_row) - self._pending_table_row = "" - else: - text_so_far = "".join(self._buffer) - lines = text_so_far.strip().split("\n") - last_line = lines[-1] if lines else "" - currently_in_table = last_line.strip().startswith("|") - - if currently_in_table and "\n" not in chunk: - self._pending_table_row += chunk - return False - - if self._pending_table_row: - self._buffer.append(self._pending_table_row) - self._pending_table_row = "" - - self._buffer.append(chunk) - return True + return self._append_text_in_current_mode(chunk) def _slice_styled_segments(self, target_text: str) -> list[tuple[str, bool]]: """Trim styled buffer to the tail matching the provided text length.""" @@ -720,13 +819,10 @@ def handle_tool_event(self, event_type: str, info: dict[str, Any] | None = None) tool_name = info.get("tool_name", "unknown") if info else "unknown" if event_type == "start": - if streams_arguments: - self._switch_to_plain_text() - self.update(f"\n→ Calling {tool_name}\n") - else: + self._begin_tool_mode() + if not streams_arguments: self._pause_progress_display() - self._switch_to_plain_text() - self.update(f"\n→ Calling {tool_name}\n") + self.update(f"→ Calling {tool_name}\n") return if event_type == "delta": if streams_arguments and info and "chunk" in info: @@ -734,12 +830,9 @@ def handle_tool_event(self, event_type: str, info: dict[str, Any] | None = None) elif event_type == "text": self._pause_progress_display() elif event_type == "stop": - if streams_arguments: - self.update("\n") - self.close() - else: - self.update("\n") - self.close() + self._end_tool_mode() + if not streams_arguments: + self._resume_progress_display() except Exception as exc: logger.warning( "Error handling tool event", diff --git a/tests/unit/fast_agent/llm/test_openai_stream_dedup.py b/tests/unit/fast_agent/llm/test_openai_stream_dedup.py new file mode 100644 index 000000000..1cea3f9e0 --- /dev/null +++ b/tests/unit/fast_agent/llm/test_openai_stream_dedup.py @@ -0,0 
+1,21 @@ +from fast_agent.llm.provider.openai.llm_openai import OpenAILLM + + +def test_extract_incremental_delta_with_cumulative_content() -> None: + delta, cumulative = OpenAILLM._extract_incremental_delta("Hello, world", "") + assert delta == "Hello, world" + assert cumulative == "Hello, world" + + delta, cumulative = OpenAILLM._extract_incremental_delta("Hello, world!", cumulative) + assert delta == "!" + assert cumulative == "Hello, world!" + + +def test_extract_incremental_delta_with_non_cumulative_content() -> None: + delta, cumulative = OpenAILLM._extract_incremental_delta("Part 1", "") + assert delta == "Part 1" + assert cumulative == "Part 1" + + delta, cumulative = OpenAILLM._extract_incremental_delta("Part 2", cumulative) + assert delta == "Part 2" + assert cumulative == "Part 1Part 2" diff --git a/tests/unit/fast_agent/ui/test_markdown_truncator_streaming.py b/tests/unit/fast_agent/ui/test_markdown_truncator_streaming.py index c553d3c32..ae87da94d 100644 --- a/tests/unit/fast_agent/ui/test_markdown_truncator_streaming.py +++ b/tests/unit/fast_agent/ui/test_markdown_truncator_streaming.py @@ -259,3 +259,23 @@ def test_streaming_truncation_indented_code_block() -> None: assert truncated.strip(), f"no content produced for height={height}" assert truncated.lstrip().startswith("```"), "expected synthetic fence for indented block" + + +def test_streaming_truncation_avoids_duplicate_table_header() -> None: + truncator = MarkdownTruncator(target_height_ratio=0.5) + original = ( + "Intro\n" + "| Mission | Date |\n" + "| --- | --- |\n" + "| Apollo 11 | 1969 |\n" + "| Apollo 12 | 1969 |\n" + ) + + truncated = ( + "| Mission | Date |\n" + "| --- | --- |\n" + "| Apollo 12 | 1969 |\n" + ) + + result = truncator._ensure_table_header_if_needed(original, truncated) + assert result.count("| Mission | Date |") == 1 diff --git a/tests/unit/fast_agent/ui/test_streaming_mode_switch.py b/tests/unit/fast_agent/ui/test_streaming_mode_switch.py new file mode 100644 index 000000000..7c8130661 --- /dev/null +++ b/tests/unit/fast_agent/ui/test_streaming_mode_switch.py @@ -0,0 +1,68 @@ +from fast_agent.config import Settings +from fast_agent.llm.stream_types import StreamChunk +from fast_agent.ui.console_display import ConsoleDisplay, _StreamingMessageHandle + + +def _make_handle(streaming_mode: str = "markdown") -> _StreamingMessageHandle: + settings = Settings() + settings.logger.streaming = streaming_mode + display = ConsoleDisplay(settings) + return _StreamingMessageHandle( + display=display, + bottom_items=None, + highlight_index=None, + max_item_length=None, + use_plain_text=streaming_mode == "plain", + header_left="", + header_right="", + progress_display=None, + ) + + +def test_reasoning_stream_switches_back_to_markdown() -> None: + handle = _make_handle("markdown") + + handle._handle_stream_chunk(StreamChunk("Intro")) + assert handle._use_plain_text is False + + handle._handle_stream_chunk(StreamChunk("Thinking", is_reasoning=True)) + assert handle._use_plain_text is True + assert handle._reasoning_active is True + + handle._handle_stream_chunk(StreamChunk("Answer")) + assert handle._use_plain_text is False + assert handle._reasoning_active is False + + text = "".join(handle._buffer) + intro_idx = text.find("Intro") + thinking_idx = text.find("Thinking") + answer_idx = text.find("Answer") + assert intro_idx != -1 + assert thinking_idx != -1 + assert answer_idx != -1 + assert "\n" in text[intro_idx + len("Intro") : thinking_idx] + assert "\n" in text[thinking_idx + len("Thinking") : answer_idx] + + 
+def test_tool_mode_switches_back_to_markdown() -> None: + handle = _make_handle("markdown") + + handle._handle_chunk("Intro") + handle._begin_tool_mode() + assert handle._use_plain_text is True + + handle._handle_chunk("Calling tool") + handle._end_tool_mode() + assert handle._use_plain_text is False + + handle._handle_chunk("Result") + + text = "".join(handle._buffer) + intro_idx = text.find("Intro") + tool_idx = text.find("Calling tool") + result_idx = text.find("Result") + assert intro_idx != -1 + assert tool_idx != -1 + assert result_idx != -1 + assert "\n" in text[intro_idx + len("Intro") : tool_idx] + assert "\n" in text[tool_idx + len("Calling tool") : result_idx] diff --git a/tests/unit/fast_agent/ui/test_streaming_table_chunking.py b/tests/unit/fast_agent/ui/test_streaming_table_chunking.py new file mode 100644 index 000000000..52a571bc1 --- /dev/null +++ b/tests/unit/fast_agent/ui/test_streaming_table_chunking.py @@ -0,0 +1,55 @@ +from fast_agent.config import Settings +from fast_agent.llm.stream_types import StreamChunk +from fast_agent.ui.console_display import ConsoleDisplay, _StreamingMessageHandle + + +def _make_handle() -> _StreamingMessageHandle: + settings = Settings() + settings.logger.streaming = "markdown" + display = ConsoleDisplay(settings) + return _StreamingMessageHandle( + display=display, + bottom_items=None, + highlight_index=None, + max_item_length=None, + use_plain_text=False, + header_left="", + header_right="", + progress_display=None, + ) + + +def test_table_rows_do_not_duplicate_when_streaming_in_parts() -> None: + handle = _make_handle() + + chunks = ["| Mission | ", "Landing Date |", "\n"] + for chunk in chunks: + handle._handle_chunk(chunk) + + text = "".join(handle._buffer) + assert text == "".join(chunks) + + +def test_table_rows_do_not_duplicate_when_reasoning_interrupts() -> None: + handle = _make_handle() + + handle._handle_chunk("| Mission ") + handle._handle_stream_chunk(StreamChunk("thinking", is_reasoning=True)) + handle._handle_stream_chunk(StreamChunk(" done", is_reasoning=False)) + handle._handle_chunk("Mission | | Landing Date |\n") + + text = "".join(handle._buffer) + assert text.count("| Mission Mission |") == 0 + assert text.count("| Mission ") == 1 + + +def test_table_pending_row_not_duplicated_after_reasoning() -> None: + handle = _make_handle() + + handle._handle_stream_chunk(StreamChunk("thinking", is_reasoning=True)) + handle._handle_stream_chunk(StreamChunk(" |", is_reasoning=False)) + assert handle._pending_table_row == " |" + + handle._handle_stream_chunk(StreamChunk(" Fact |\n", is_reasoning=False)) + text = "".join(handle._buffer) + assert text.endswith(" | Fact |\n") diff --git a/typesafe.md b/typesafe.md index 7cfca9c6d..2d1750602 100644 --- a/typesafe.md +++ b/typesafe.md @@ -110,6 +110,39 @@ Python 3.13+ and current typing guidance. 7. **Enforce**: add `ty check` to CI once warnings are near-zero; tighten rules to `error` as we converge. +## **IMPORTANT: Refactoring `*args, **kwargs` Signatures** + +When replacing `def func(*args, **kwargs)` with explicit parameters to satisfy the type checker, +**always verify all call sites first**. The `*args` pattern accepts any number of positional +arguments, and removing it can silently break callers that pass positional args you didn't capture. 
+
+**Before:**
+```python
+def __init__(self, *args, **kwargs) -> None:
+    super().__init__(*args, **kwargs)
+```
+
+**After (WRONG - missed `read_timeout`):**
+```python
+def __init__(self, read_stream, write_stream, **kwargs) -> None:
+    super().__init__(read_stream, write_stream, **kwargs)
+```
+
+**After (CORRECT):**
+```python
+def __init__(self, read_stream, write_stream, read_timeout=None, **kwargs) -> None:
+    super().__init__(read_stream, write_stream, read_timeout, **kwargs)
+```
+
+**Checklist before removing `*args`:**
+1. Grep for all call sites of the function/class
+2. Check if any caller passes positional arguments beyond what you're capturing
+3. Check type hints or protocols that define expected signatures (e.g., `Callable[[A, B, C], R]`)
+4. When in doubt, keep `*args, **kwargs` and pop known keys from `kwargs`, or add explicit
+   params for all positional args callers use
+
+This error is insidious because it causes runtime `TypeError` failures, not type-check failures.
+
 ## Decision Log (initial)
 
 - We will use `ty: ignore[rule]` over bare `ty: ignore` and avoid `type: ignore` unless an external

From c21be4549a8854662ca10bca24db4cc54bcf8280 Mon Sep 17 00:00:00 2001
From: evalstate <1936278+evalstate@users.noreply.github.com>
Date: Tue, 23 Dec 2025 19:37:10 +0000
Subject: [PATCH 07/15] types and tests

---
 src/fast_agent/core/executor/executor.py     |   2 +-
 src/fast_agent/llm/internal/passthrough.py   |   2 +-
 src/fast_agent/mcp/elicitation_handlers.py   |  15 +--
 src/fast_agent/mcp/mcp_connection_manager.py |  14 +--
 src/fast_agent/mcp/stdio_tracking_simple.py  |   4 +-
 .../ui/test_streaming_mode_switch.py         | 102 +++++++++++-------
 .../ui/test_streaming_table_chunking.py      |  79 ++++++++++----
 7 files changed, 146 insertions(+), 72 deletions(-)

diff --git a/src/fast_agent/core/executor/executor.py b/src/fast_agent/core/executor/executor.py
index bf6842220..1ae5d1802 100644
--- a/src/fast_agent/core/executor/executor.py
+++ b/src/fast_agent/core/executor/executor.py
@@ -52,7 +52,7 @@ def __init__(
         self,
         engine: str,
         config: ExecutorConfig | None = None,
-        signal_bus: SignalHandler = None,
+        signal_bus: SignalHandler | None = None,
         context: Union["Context", None] = None,
         **kwargs,
     ) -> None:
diff --git a/src/fast_agent/llm/internal/passthrough.py b/src/fast_agent/llm/internal/passthrough.py
index 612c492e8..bb896a612 100644
--- a/src/fast_agent/llm/internal/passthrough.py
+++ b/src/fast_agent/llm/internal/passthrough.py
@@ -30,7 +30,7 @@ class PassthroughLLM(FastAgentLLM):
     """
 
     def __init__(
-        self, provider=Provider.FAST_AGENT, name: str = "Passthrough", **kwargs: dict[str, Any]
+        self, provider=Provider.FAST_AGENT, name: str = "Passthrough", **kwargs: Any
     ) -> None:
         super().__init__(name=name, provider=provider, **kwargs)
         self.logger = get_logger(__name__)
diff --git a/src/fast_agent/mcp/elicitation_handlers.py b/src/fast_agent/mcp/elicitation_handlers.py
index c4d375ce6..3d308ffb3 100644
--- a/src/fast_agent/mcp/elicitation_handlers.py
+++ b/src/fast_agent/mcp/elicitation_handlers.py
@@ -61,15 +61,17 @@ async def forms_elicitation_handler(
         agent_name = "Unknown Agent"
 
     # Create human input request
+    # Note: requestedSchema is only present on ElicitRequestFormParams, not ElicitRequestURLParams
+    requested_schema = getattr(params, "requestedSchema", None)
     request = HumanInputRequest(
         prompt=params.message,
-        description=f"Schema: {params.requestedSchema}" if params.requestedSchema else None,
+        description=f"Schema: {requested_schema}" if requested_schema else None,
         request_id=f"elicit_{id(params)}",
         metadata={
"agent_name": agent_name, "server_name": server_name, "elicitation": True, - "requested_schema": params.requestedSchema, + "requested_schema": requested_schema, }, ) @@ -98,20 +100,21 @@ async def forms_elicitation_handler( try: from fast_agent.human_input.elicitation_state import elicitation_state - elicitation_state.disable_server(server_name) + if server_name is not None: + elicitation_state.disable_server(server_name) except Exception: # Do not fail the flow if state update fails pass return ElicitResult(action="cancel") # Parse response based on schema if provided - if params.requestedSchema: + if requested_schema: # Check if the response is already JSON (from our form) try: # Try to parse as JSON first (from schema-driven form) content = json.loads(response_data) # Validate that all required fields are present - required_fields = params.requestedSchema.get("required", []) + required_fields = requested_schema.get("required", []) for field in required_fields: if field not in content: logger.warning(f"Missing required field '{field}' in elicitation response") @@ -119,7 +122,7 @@ async def forms_elicitation_handler( except json.JSONDecodeError: # Not JSON, try to handle as simple text response # This is a fallback for simple schemas or text-based responses - properties = params.requestedSchema.get("properties", {}) + properties = requested_schema.get("properties", {}) if len(properties) == 1: # Single field schema - try to parse based on type field_name = next(iter(properties)) diff --git a/src/fast_agent/mcp/mcp_connection_manager.py b/src/fast_agent/mcp/mcp_connection_manager.py index 244c26882..eaca96615 100644 --- a/src/fast_agent/mcp/mcp_connection_manager.py +++ b/src/fast_agent/mcp/mcp_connection_manager.py @@ -5,6 +5,7 @@ import asyncio import traceback from datetime import timedelta +from contextlib import AbstractAsyncContextManager from typing import TYPE_CHECKING, AsyncGenerator, Callable, Union import httpx @@ -112,19 +113,15 @@ def __init__( server_config: MCPServerSettings, transport_context_factory: Callable[ [], - AsyncGenerator[ + AbstractAsyncContextManager[ tuple[ MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectSendStream[JSONRPCMessage], GetSessionIdCallback | None, - ], - None, + ] ], ], - client_session_factory: Callable[ - [MemoryObjectReceiveStream, MemoryObjectSendStream, timedelta | None], - ClientSession, - ], + client_session_factory: Callable[..., ClientSession], ) -> None: self.server_name = server_name self.server_config = server_config @@ -274,6 +271,7 @@ async def _server_lifecycle_task(server_conn: ServerConnection) -> None: server_conn.session_id = "local" server_conn.create_session(read_stream, write_stream) + assert server_conn.session is not None try: async with server_conn.session: @@ -405,6 +403,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): # Then close the task group if it's active if self._task_group_active: + assert self._task_group is not None await self._task_group.__aexit__(exc_type, exc_val, exc_tb) self._task_group_active = False self._task_group = None @@ -596,6 +595,7 @@ def channel_hook(event): return self.running_servers[server_name] self.running_servers[server_name] = server_conn + assert self._tg is not None self._tg.start_soon(_server_lifecycle_task, server_conn) logger.info(f"{server_name}: Up and running with a persistent connection!") diff --git a/src/fast_agent/mcp/stdio_tracking_simple.py b/src/fast_agent/mcp/stdio_tracking_simple.py index 43bba2e60..5fffb1d72 100644 --- 
a/src/fast_agent/mcp/stdio_tracking_simple.py +++ b/src/fast_agent/mcp/stdio_tracking_simple.py @@ -2,7 +2,7 @@ import logging from contextlib import asynccontextmanager -from typing import TYPE_CHECKING, AsyncGenerator, Callable +from typing import TYPE_CHECKING, AsyncGenerator, Callable, TextIO from mcp.client.stdio import StdioServerParameters, stdio_client @@ -22,7 +22,7 @@ async def tracking_stdio_client( server_params: StdioServerParameters, *, channel_hook: ChannelHook | None = None, - errlog: Callable[[str], None] | None = None, + errlog: TextIO | None = None, ) -> AsyncGenerator[ tuple[ObjectReceiveStream[SessionMessage | Exception], ObjectSendStream[SessionMessage]], None ]: diff --git a/tests/unit/fast_agent/ui/test_streaming_mode_switch.py b/tests/unit/fast_agent/ui/test_streaming_mode_switch.py index 7c8130661..6a73a186a 100644 --- a/tests/unit/fast_agent/ui/test_streaming_mode_switch.py +++ b/tests/unit/fast_agent/ui/test_streaming_mode_switch.py @@ -1,8 +1,30 @@ from fast_agent.config import Settings +from fast_agent.ui import console from fast_agent.llm.stream_types import StreamChunk from fast_agent.ui.console_display import ConsoleDisplay, _StreamingMessageHandle +def _set_console_size(width: int = 80, height: int = 24) -> tuple[object | None, object | None]: + original_width = getattr(console.console, "_width", None) + original_height = getattr(console.console, "_height", None) + console.console._width = width + console.console._height = height + return original_width, original_height + + +def _restore_console_size(original_width: object | None, original_height: object | None) -> None: + if original_width is None: + if hasattr(console.console, "_width"): + delattr(console.console, "_width") + else: + console.console._width = original_width + if original_height is None: + if hasattr(console.console, "_height"): + delattr(console.console, "_height") + else: + console.console._height = original_height + + def _make_handle(streaming_mode: str = "markdown") -> _StreamingMessageHandle: settings = Settings() settings.logger.streaming = streaming_mode @@ -20,49 +42,55 @@ def _make_handle(streaming_mode: str = "markdown") -> _StreamingMessageHandle: def test_reasoning_stream_switches_back_to_markdown() -> None: + original_width, original_height = _set_console_size() handle = _make_handle("markdown") + try: + handle._handle_stream_chunk(StreamChunk("Intro")) + assert handle._use_plain_text is False - handle._handle_stream_chunk(StreamChunk("Intro")) - assert handle._use_plain_text is False + handle._handle_stream_chunk(StreamChunk("Thinking", is_reasoning=True)) + assert handle._use_plain_text is True + assert handle._reasoning_active is True - handle._handle_stream_chunk(StreamChunk("Thinking", is_reasoning=True)) - assert handle._use_plain_text is True - assert handle._reasoning_active is True + handle._handle_stream_chunk(StreamChunk("Answer")) + assert handle._use_plain_text is False + assert handle._reasoning_active is False - handle._handle_stream_chunk(StreamChunk("Answer")) - assert handle._use_plain_text is False - assert handle._reasoning_active is False - - text = "".join(handle._buffer) - intro_idx = text.find("Intro") - thinking_idx = text.find("Thinking") - answer_idx = text.find("Answer") - assert intro_idx != -1 - assert thinking_idx != -1 - assert answer_idx != -1 - assert "\n" in text[intro_idx + len("Intro") : thinking_idx] - assert "\n" in text[thinking_idx + len("Thinking") : answer_idx] + text = "".join(handle._buffer) + intro_idx = text.find("Intro") + 
thinking_idx = text.find("Thinking") + answer_idx = text.find("Answer") + assert intro_idx != -1 + assert thinking_idx != -1 + assert answer_idx != -1 + assert "\n" in text[intro_idx + len("Intro") : thinking_idx] + assert "\n" in text[thinking_idx + len("Thinking") : answer_idx] + finally: + _restore_console_size(original_width, original_height) def test_tool_mode_switches_back_to_markdown() -> None: + original_width, original_height = _set_console_size() handle = _make_handle("markdown") + try: + handle._handle_chunk("Intro") + handle._begin_tool_mode() + assert handle._use_plain_text is True + + handle._handle_chunk("Calling tool") + handle._end_tool_mode() + assert handle._use_plain_text is False + + handle._handle_chunk("Result") - handle._handle_chunk("Intro") - handle._begin_tool_mode() - assert handle._use_plain_text is True - - handle._handle_chunk("Calling tool") - handle._end_tool_mode() - assert handle._use_plain_text is False - - handle._handle_chunk("Result") - - text = "".join(handle._buffer) - intro_idx = text.find("Intro") - tool_idx = text.find("Calling tool") - result_idx = text.find("Result") - assert intro_idx != -1 - assert tool_idx != -1 - assert result_idx != -1 - assert "\n" in text[intro_idx + len("Intro") : tool_idx] - assert "\n" in text[tool_idx + len("Calling tool") : result_idx] + text = "".join(handle._buffer) + intro_idx = text.find("Intro") + tool_idx = text.find("Calling tool") + result_idx = text.find("Result") + assert intro_idx != -1 + assert tool_idx != -1 + assert result_idx != -1 + assert "\n" in text[intro_idx + len("Intro") : tool_idx] + assert "\n" in text[tool_idx + len("Calling tool") : result_idx] + finally: + _restore_console_size(original_width, original_height) diff --git a/tests/unit/fast_agent/ui/test_streaming_table_chunking.py b/tests/unit/fast_agent/ui/test_streaming_table_chunking.py index 52a571bc1..944ec7893 100644 --- a/tests/unit/fast_agent/ui/test_streaming_table_chunking.py +++ b/tests/unit/fast_agent/ui/test_streaming_table_chunking.py @@ -1,5 +1,6 @@ from fast_agent.config import Settings from fast_agent.llm.stream_types import StreamChunk +from fast_agent.ui import console from fast_agent.ui.console_display import ConsoleDisplay, _StreamingMessageHandle @@ -20,36 +21,78 @@ def _make_handle() -> _StreamingMessageHandle: def test_table_rows_do_not_duplicate_when_streaming_in_parts() -> None: + original_width = getattr(console.console, "_width", None) + original_height = getattr(console.console, "_height", None) + console.console._width = 80 + console.console._height = 24 handle = _make_handle() - chunks = ["| Mission | ", "Landing Date |", "\n"] - for chunk in chunks: - handle._handle_chunk(chunk) + try: + chunks = ["| Mission | ", "Landing Date |", "\n"] + for chunk in chunks: + handle._handle_chunk(chunk) - text = "".join(handle._buffer) - assert text == "".join(chunks) + text = "".join(handle._buffer) + assert text == "".join(chunks) + finally: + if original_width is None: + delattr(console.console, "_width") + else: + console.console._width = original_width + if original_height is None: + delattr(console.console, "_height") + else: + console.console._height = original_height def test_table_rows_do_not_duplicate_when_reasoning_interrupts() -> None: + original_width = getattr(console.console, "_width", None) + original_height = getattr(console.console, "_height", None) + console.console._width = 80 + console.console._height = 24 handle = _make_handle() - handle._handle_chunk("| Mission ") - 
handle._handle_stream_chunk(StreamChunk("thinking", is_reasoning=True)) - handle._handle_stream_chunk(StreamChunk(" done", is_reasoning=False)) - handle._handle_chunk("Mission | | Landing Date |\n") + try: + handle._handle_chunk("| Mission ") + handle._handle_stream_chunk(StreamChunk("thinking", is_reasoning=True)) + handle._handle_stream_chunk(StreamChunk(" done", is_reasoning=False)) + handle._handle_chunk("Mission | | Landing Date |\n") - text = "".join(handle._buffer) - assert text.count("| Mission Mission |") == 0 - assert text.count("| Mission ") == 1 + text = "".join(handle._buffer) + assert text.count("| Mission Mission |") == 0 + assert text.count("| Mission ") == 1 + finally: + if original_width is None: + delattr(console.console, "_width") + else: + console.console._width = original_width + if original_height is None: + delattr(console.console, "_height") + else: + console.console._height = original_height def test_table_pending_row_not_duplicated_after_reasoning() -> None: + original_width = getattr(console.console, "_width", None) + original_height = getattr(console.console, "_height", None) + console.console._width = 80 + console.console._height = 24 handle = _make_handle() - handle._handle_stream_chunk(StreamChunk("thinking", is_reasoning=True)) - handle._handle_stream_chunk(StreamChunk(" |", is_reasoning=False)) - assert handle._pending_table_row == " |" + try: + handle._handle_stream_chunk(StreamChunk("thinking", is_reasoning=True)) + handle._handle_stream_chunk(StreamChunk(" |", is_reasoning=False)) + assert handle._pending_table_row == " |" - handle._handle_stream_chunk(StreamChunk(" Fact |\n", is_reasoning=False)) - text = "".join(handle._buffer) - assert text.endswith(" | Fact |\n") + handle._handle_stream_chunk(StreamChunk(" Fact |\n", is_reasoning=False)) + text = "".join(handle._buffer) + assert text.endswith(" | Fact |\n") + finally: + if original_width is None: + delattr(console.console, "_width") + else: + console.console._width = original_width + if original_height is None: + delattr(console.console, "_height") + else: + console.console._height = original_height From 7c7d687b0ba53372b223bf6a0e8c547ecc477b4f Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Tue, 23 Dec 2025 19:40:44 +0000 Subject: [PATCH 08/15] lint --- src/fast_agent/mcp/mcp_connection_manager.py | 4 ++-- tests/unit/fast_agent/ui/test_streaming_mode_switch.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/fast_agent/mcp/mcp_connection_manager.py b/src/fast_agent/mcp/mcp_connection_manager.py index eaca96615..211b8daf1 100644 --- a/src/fast_agent/mcp/mcp_connection_manager.py +++ b/src/fast_agent/mcp/mcp_connection_manager.py @@ -4,9 +4,9 @@ import asyncio import traceback -from datetime import timedelta from contextlib import AbstractAsyncContextManager -from typing import TYPE_CHECKING, AsyncGenerator, Callable, Union +from datetime import timedelta +from typing import TYPE_CHECKING, Callable, Union import httpx from anyio import Event, Lock, create_task_group diff --git a/tests/unit/fast_agent/ui/test_streaming_mode_switch.py b/tests/unit/fast_agent/ui/test_streaming_mode_switch.py index 6a73a186a..b1a42fab1 100644 --- a/tests/unit/fast_agent/ui/test_streaming_mode_switch.py +++ b/tests/unit/fast_agent/ui/test_streaming_mode_switch.py @@ -1,6 +1,6 @@ from fast_agent.config import Settings -from fast_agent.ui import console from fast_agent.llm.stream_types import StreamChunk +from fast_agent.ui import console from 
fast_agent.ui.console_display import ConsoleDisplay, _StreamingMessageHandle From 4bd59704fd2e5441faa852eb7e9ad1332eb50ac5 Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Tue, 23 Dec 2025 20:47:32 +0000 Subject: [PATCH 09/15] type safety, streaming fix --- src/fast_agent/llm/model_database.py | 12 + .../llm/provider/anthropic/llm_anthropic.py | 215 +++++++++++------- .../multipart_converter_anthropic.py | 51 +++-- src/fast_agent/ui/console_display.py | 5 + src/fast_agent/ui/rich_progress.py | 7 +- 5 files changed, 182 insertions(+), 108 deletions(-) diff --git a/src/fast_agent/llm/model_database.py b/src/fast_agent/llm/model_database.py index 31b23a9a9..0ae3c9d28 100644 --- a/src/fast_agent/llm/model_database.py +++ b/src/fast_agent/llm/model_database.py @@ -31,6 +31,9 @@ class ModelParameters(BaseModel): stream_mode: Literal["openai", "manual"] = "openai" """Determines how streaming deltas should be processed.""" + system_role: None | str = "system" + """Role to use for the System Prompt""" + class ModelDatabase: """Centralized model configuration database""" @@ -169,6 +172,15 @@ class ModelDatabase: context_window=65536, max_output_tokens=32768, tokenizes=TEXT_ONLY ) + DEEPSEEK_V_32 = ModelParameters( + context_window=65536, + max_output_tokens=32768, + tokenizes=TEXT_ONLY, + json_mode="object", + reasoning="gpt-oss", + system_role="developer", + ) + DEEPSEEK_DISTILL = ModelParameters( context_window=131072, max_output_tokens=131072, diff --git a/src/fast_agent/llm/provider/anthropic/llm_anthropic.py b/src/fast_agent/llm/provider/anthropic/llm_anthropic.py index 36bb441d5..fc02191d9 100644 --- a/src/fast_agent/llm/provider/anthropic/llm_anthropic.py +++ b/src/fast_agent/llm/provider/anthropic/llm_anthropic.py @@ -1,14 +1,23 @@ import asyncio import json +import os +from datetime import datetime +from pathlib import Path from typing import Any, Type, Union, cast from anthropic import APIError, AsyncAnthropic, AuthenticationError from anthropic.lib.streaming import AsyncMessageStream from anthropic.types import ( + InputJSONDelta, Message, MessageParam, + RawContentBlockDeltaEvent, + RawContentBlockStartEvent, + RawContentBlockStopEvent, + RawMessageDeltaEvent, TextBlock, TextBlockParam, + TextDelta, ToolParam, ToolUseBlock, ToolUseBlockParam, @@ -38,6 +47,7 @@ AnthropicConverter, ) from fast_agent.llm.provider_types import Provider +from fast_agent.llm.stream_types import StreamChunk from fast_agent.llm.usage_tracking import TurnUsage from fast_agent.mcp.helpers.content_helpers import text_content from fast_agent.types import PromptMessageExtended @@ -46,12 +56,42 @@ DEFAULT_ANTHROPIC_MODEL = "claude-sonnet-4-0" STRUCTURED_OUTPUT_TOOL_NAME = "return_structured_output" +# Stream capture mode - when enabled, saves all streaming chunks to files for debugging +# Set FAST_AGENT_LLM_TRACE=1 (or any non-empty value) to enable +STREAM_CAPTURE_ENABLED = bool(os.environ.get("FAST_AGENT_LLM_TRACE")) +STREAM_CAPTURE_DIR = Path("stream-debug") + # Type alias for system field - can be string or list of text blocks with cache control SystemParam = Union[str, list[TextBlockParam]] logger = get_logger(__name__) +def _stream_capture_filename(turn: int) -> Path | None: + """Generate filename for stream capture. 
Returns None if capture is disabled.""" + if not STREAM_CAPTURE_ENABLED: + return None + STREAM_CAPTURE_DIR.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + return STREAM_CAPTURE_DIR / f"anthropic_{timestamp}_turn{turn}" + + +def _save_stream_chunk(filename_base: Path | None, chunk: Any) -> None: + """Save a streaming chunk to file when capture mode is enabled.""" + if not filename_base: + return + try: + chunk_file = filename_base.with_name(f"{filename_base.name}.jsonl") + try: + payload: Any = chunk.model_dump() + except Exception: + payload = {"type": type(chunk).__name__, "str": str(chunk)} + with open(chunk_file, "a") as f: + f.write(json.dumps(payload) + "\n") + except Exception as e: + logger.debug(f"Failed to save stream chunk: {e}") + + class AnthropicLLM(FastAgentLLM[MessageParam, Message]): CONVERSATION_CACHE_WALK_DISTANCE = 6 MAX_CONVERSATION_CACHE_BLOCKS = 2 @@ -239,6 +279,7 @@ async def _process_stream( self, stream: AsyncMessageStream, model: str, + capture_filename: Path | None = None, ) -> Message: """Process the streaming response and display real-time token usage.""" # Track estimated output tokens by counting text chunks @@ -249,70 +290,68 @@ async def _process_stream( # Process the raw event stream to get token counts # Cancellation is handled via asyncio.Task.cancel() which raises CancelledError async for event in stream: - if ( - event.type == "content_block_start" - and hasattr(event, "content_block") - and getattr(event.content_block, "type", None) == "tool_use" - ): - content_block = event.content_block - tool_streams[event.index] = { - "name": content_block.name, - "id": content_block.id, - "buffer": [], - } - self._notify_tool_stream_listeners( - "start", - { - "tool_name": content_block.name, - "tool_use_id": content_block.id, - "index": event.index, - "streams_arguments": False, # Anthropic doesn't stream arguments - }, - ) - self.logger.info( - "Model started streaming tool input", - data={ - "progress_action": ProgressAction.CALLING_TOOL, - "agent_name": self.name, - "model": model, - "tool_name": content_block.name, - "tool_use_id": content_block.id, - "tool_event": "start", - }, - ) - continue + # Save chunk if stream capture is enabled + _save_stream_chunk(capture_filename, event) - if ( - event.type == "content_block_delta" - and hasattr(event, "delta") - and event.delta.type == "input_json_delta" - ): - info = tool_streams.get(event.index) - if info is not None: - chunk = event.delta.partial_json or "" - info["buffer"].append(chunk) - preview = chunk if len(chunk) <= 80 else chunk[:77] + "..." 
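The capture hooks above write one JSON object per raw stream event, appended to stream-debug/anthropic_<timestamp>_turn<N>.jsonl whenever FAST_AGENT_LLM_TRACE is set. A minimal offline reader for such a file, as a hedged sketch; the file name and the printed summary are illustrative, not part of the patch:

import json
from pathlib import Path

def summarize_capture(path: Path) -> None:
    """Print one line per captured event: its type discriminator and raw size."""
    with path.open() as f:
        for raw in f:
            event = json.loads(raw)
            # model_dump() output keeps the event's "type" field; the fallback
            # payload written on serialization errors carries "type" as well
            print(event.get("type", "<unknown>"), f"({len(raw)} bytes)")

summarize_capture(Path("stream-debug/anthropic_20251223_210000_turn1.jsonl"))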
+ if isinstance(event, RawContentBlockStartEvent): + content_block = event.content_block + if isinstance(content_block, ToolUseBlock): + tool_streams[event.index] = { + "name": content_block.name, + "id": content_block.id, + "buffer": [], + } self._notify_tool_stream_listeners( - "delta", + "start", { - "tool_name": info.get("name"), - "tool_use_id": info.get("id"), + "tool_name": content_block.name, + "tool_use_id": content_block.id, "index": event.index, - "chunk": chunk, - "streams_arguments": False, + "streams_arguments": False, # Anthropic doesn't stream arguments }, ) - self.logger.debug( - "Streaming tool input delta", + self.logger.info( + "Model started streaming tool input", data={ - "tool_name": info.get("name"), - "tool_use_id": info.get("id"), - "chunk": preview, + "progress_action": ProgressAction.CALLING_TOOL, + "agent_name": self.name, + "model": model, + "tool_name": content_block.name, + "tool_use_id": content_block.id, + "tool_event": "start", }, ) - continue - - if event.type == "content_block_stop" and event.index in tool_streams: + continue + + if isinstance(event, RawContentBlockDeltaEvent): + delta = event.delta + if isinstance(delta, InputJSONDelta): + info = tool_streams.get(event.index) + if info is not None: + chunk = delta.partial_json or "" + info["buffer"].append(chunk) + preview = chunk if len(chunk) <= 80 else chunk[:77] + "..." + self._notify_tool_stream_listeners( + "delta", + { + "tool_name": info.get("name"), + "tool_use_id": info.get("id"), + "index": event.index, + "chunk": chunk, + "streams_arguments": False, + }, + ) + self.logger.debug( + "Streaming tool input delta", + data={ + "tool_name": info.get("name"), + "tool_use_id": info.get("id"), + "chunk": preview, + }, + ) + continue + + if isinstance(event, RawContentBlockStopEvent) and event.index in tool_streams: info = tool_streams.pop(event.index) preview_raw = "".join(info.get("buffer", [])) if preview_raw: @@ -350,30 +389,28 @@ async def _process_stream( continue # Count tokens in real-time from content_block_delta events - if ( - event.type == "content_block_delta" - and hasattr(event, "delta") - and event.delta.type == "text_delta" - ): - # Use base class method for token estimation and progress emission - estimated_tokens = self._update_streaming_progress( - event.delta.text, model, estimated_tokens - ) - self._notify_tool_stream_listeners( - "text", - { - "chunk": event.delta.text, - "index": event.index, - "streams_arguments": False, - }, - ) + if isinstance(event, RawContentBlockDeltaEvent): + delta = event.delta + if isinstance(delta, TextDelta): + # Notify stream listeners for UI streaming + self._notify_stream_listeners( + StreamChunk(text=delta.text, is_reasoning=False) + ) + # Use base class method for token estimation and progress emission + estimated_tokens = self._update_streaming_progress( + delta.text, model, estimated_tokens + ) + self._notify_tool_stream_listeners( + "text", + { + "chunk": delta.text, + "index": event.index, + "streams_arguments": False, + }, + ) # Also check for final message_delta events with actual usage info - elif ( - event.type == "message_delta" - and hasattr(event, "usage") - and event.usage.output_tokens - ): + elif isinstance(event, RawMessageDeltaEvent) and event.usage.output_tokens: actual_tokens = event.usage.output_tokens # Emit final progress with actual token count token_str = str(actual_tokens).rjust(5) @@ -401,10 +438,10 @@ async def _process_stream( raise # Re-raise to be handled by _anthropic_completion except Exception as error: 
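The hunk above trades event.type string comparisons for isinstance() checks against the SDK's raw event classes, so a type checker can narrow event.content_block and event.delta to the matching union member instead of flagging unverified attribute access. A minimal sketch of the narrowing pattern, using stand-in dataclasses rather than the anthropic types:

from dataclasses import dataclass

@dataclass
class FakeTextDelta:
    text: str

@dataclass
class FakeInputJSONDelta:
    partial_json: str

def handle(delta: FakeTextDelta | FakeInputJSONDelta) -> str:
    if isinstance(delta, FakeTextDelta):
        # narrowed to FakeTextDelta: .text is known to the checker
        return delta.text
    # only FakeInputJSONDelta remains in the union here
    return delta.partial_json

assert handle(FakeTextDelta(text="hello")) == "hello"
assert handle(FakeInputJSONDelta(partial_json='{"city": "Par')) == '{"city": "Par'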
logger.error("Unexpected error during Anthropic stream processing", exc_info=error) - # Convert to APIError for consistent handling - raise APIError(f"Stream processing error: {str(error)}") from error + # Re-raise for consistent handling - caller handles the error + raise - def _stream_failure_response(self, error: APIError, model_name: str) -> PromptMessageExtended: + def _stream_failure_response(self, error: Exception, model_name: str) -> PromptMessageExtended: """Convert streaming API errors into a graceful assistant reply.""" provider_label = ( @@ -572,11 +609,15 @@ async def _anthropic_completion( self._apply_cache_control_to_message(messages[idx]) logger.debug(f"{arguments}") + + # Generate stream capture filename once (before streaming starts) + capture_filename = _stream_capture_filename(self.chat_turn()) + # Use streaming API with helper try: async with anthropic.messages.stream(**arguments) as stream: # Process the stream - response = await self._process_stream(stream, model) + response = await self._process_stream(stream, model, capture_filename) except asyncio.CancelledError as e: reason = str(e) if e.args else "cancelled" logger.info(f"Anthropic completion cancelled: {reason}") @@ -612,9 +653,7 @@ async def _anthropic_completion( # This path shouldn't be reached anymore since we handle APIError above, # but keeping for backward compatibility logger.error(f"Unexpected error type: {type(response).__name__}", exc_info=response) - return self._stream_failure_response( - APIError(f"Unexpected error: {str(response)}"), model - ) + return self._stream_failure_response(response, model) logger.debug( f"{model} response:", @@ -640,7 +679,7 @@ async def _anthropic_completion( case "tool_use": stop_reason = LlmStopReason.TOOL_USE tool_uses: list[ToolUseBlock] = [ - c for c in response.content if c.type == "tool_use" + c for c in response.content if isinstance(c, ToolUseBlock) ] if structured_model and self._is_structured_output_request(tool_uses): stop_reason, structured_blocks = await self._handle_structured_output_response( @@ -722,7 +761,7 @@ async def _apply_prompt_provider_specific_structured( ) for content in result.content: - if content.type == "text": + if isinstance(content, TextContent): try: data = json.loads(content.text) parsed_model = model(**data) @@ -759,9 +798,9 @@ def convert_message_to_message_param(cls, message: Message, **kwargs) -> Message content = [] for content_block in message.content: - if content_block.type == "text": + if isinstance(content_block, TextBlock): content.append(TextBlock(type="text", text=content_block.text)) - elif content_block.type == "tool_use": + elif isinstance(content_block, ToolUseBlock): content.append( ToolUseBlockParam( type="tool_use", diff --git a/src/fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py b/src/fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py index 82626f70b..9f5325786 100644 --- a/src/fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py +++ b/src/fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py @@ -1,5 +1,5 @@ import re -from typing import Sequence, Union +from typing import Literal, Sequence, Union, cast from urllib.parse import urlparse from anthropic.types import ( @@ -174,36 +174,44 @@ def _convert_content_items( if is_text_content(content_item): # Handle text content text = get_text(content_item) - anthropic_blocks.append(TextBlockParam(type="text", text=text)) + if text: + anthropic_blocks.append(TextBlockParam(type="text", text=text)) elif 
is_image_content(content_item): - # Handle image content - image_content = content_item # type: ImageContent + # Handle image content - cast needed for ty type narrowing + image_content = cast("ImageContent", content_item) + mime_type = image_content.mimeType or "" # Check if image MIME type is supported - if not AnthropicConverter._is_supported_image_type(image_content.mimeType): + if not AnthropicConverter._is_supported_image_type(mime_type): data_size = len(image_content.data) if image_content.data else 0 anthropic_blocks.append( TextBlockParam( type="text", - text=f"Image with unsupported format '{image_content.mimeType}' ({data_size} bytes)", + text=f"Image with unsupported format '{mime_type}' ({data_size} bytes)", ) ) else: image_data = get_image_data(image_content) - anthropic_blocks.append( - ImageBlockParam( - type="image", - source=Base64ImageSourceParam( - type="base64", - media_type=image_content.mimeType, - data=image_data, - ), + if image_data and mime_type in SUPPORTED_IMAGE_MIME_TYPES: + anthropic_blocks.append( + ImageBlockParam( + type="image", + source=Base64ImageSourceParam( + type="base64", + media_type=cast( + "Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']", + mime_type, + ), + data=image_data, + ), + ) ) - ) elif is_resource_content(content_item): - # Handle embedded resource - block = AnthropicConverter._convert_embedded_resource(content_item, document_mode) + # Handle embedded resource - cast needed for ty type narrowing + block = AnthropicConverter._convert_embedded_resource( + cast("EmbeddedResource", content_item), document_mode + ) anthropic_blocks.append(block) return anthropic_blocks @@ -258,7 +266,12 @@ def _convert_embedded_resource( return ImageBlockParam( type="image", source=Base64ImageSourceParam( - type="base64", media_type=mime_type, data=image_data + type="base64", + media_type=cast( + "Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']", + mime_type, + ), + data=image_data, ), ) @@ -343,7 +356,7 @@ def _determine_mime_type( return resource.mimeType if resource.uri: - return guess_mime_type(resource.uri.serialize_url) + return guess_mime_type(str(resource.uri)) if hasattr(resource, "blob"): return "application/octet-stream" diff --git a/src/fast_agent/ui/console_display.py b/src/fast_agent/ui/console_display.py index c1d48fa06..77c70a697 100644 --- a/src/fast_agent/ui/console_display.py +++ b/src/fast_agent/ui/console_display.py @@ -778,10 +778,15 @@ def streaming_assistant_message( header_right=right_info, progress_display=progress_display, ) + # Pause progress display BEFORE yielding to prevent race condition with Anthropic + # (Anthropic's stream context manager may start events during __aenter__) + progress_display.pause() try: yield handle finally: handle.close() + # Resume progress display - must be explicit since we paused externally + progress_display.resume() def _display_mermaid_diagrams(self, diagrams: list[MermaidDiagram]) -> None: """Display mermaid diagram links.""" diff --git a/src/fast_agent/ui/rich_progress.py b/src/fast_agent/ui/rich_progress.py index f8582d613..a6e1ccbab 100644 --- a/src/fast_agent/ui/rich_progress.py +++ b/src/fast_agent/ui/rich_progress.py @@ -38,11 +38,12 @@ def start(self) -> None: def stop(self) -> None: """Stop and clear the progress display.""" + # Set paused first to prevent race with incoming updates + self._paused = True # Hide all tasks before stopping (like pause does) for task in self._progress.tasks: task.visible = False self._progress.stop() - self._paused = True def 
pause(self) -> None: """Pause the progress display.""" @@ -102,6 +103,10 @@ def _get_action_style(self, action: ProgressAction) -> str: def update(self, event: ProgressEvent) -> None: """Update the progress display with a new event.""" + # Skip updates when display is paused (e.g., during streaming) + if self._paused: + return + task_name = event.agent_name or "default" # Create new task if needed From c26f0258b8c9194fc1bbbc57979efc8f3e46681f Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Tue, 23 Dec 2025 21:11:45 +0000 Subject: [PATCH 10/15] type safety --- .../mcp/mcp_agent_client_session.py | 64 +++++++++---------- src/fast_agent/mcp/sampling.py | 27 ++++---- 2 files changed, 43 insertions(+), 48 deletions(-) diff --git a/src/fast_agent/mcp/mcp_agent_client_session.py b/src/fast_agent/mcp/mcp_agent_client_session.py index 1085ebfa6..c5ff62786 100644 --- a/src/fast_agent/mcp/mcp_agent_client_session.py +++ b/src/fast_agent/mcp/mcp_agent_client_session.py @@ -4,19 +4,20 @@ """ from datetime import timedelta -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from mcp import ClientSession, ServerNotification +from mcp.shared.context import RequestContext from mcp.shared.message import MessageMetadata from mcp.shared.session import ( ProgressFnT, ReceiveResultT, - SendRequestT, ) from mcp.types import ( CallToolRequest, CallToolRequestParams, CallToolResult, + ClientRequest, GetPromptRequest, GetPromptRequestParams, GetPromptResult, @@ -28,7 +29,7 @@ Root, ToolListChangedNotification, ) -from pydantic import FileUrl +from pydantic import AnyUrl, FileUrl from fast_agent.context_dependent import ContextDependent from fast_agent.core.logging.logger import get_logger @@ -42,10 +43,10 @@ logger = get_logger(__name__) -async def list_roots(ctx: ClientSession) -> ListRootsResult: +async def list_roots(context: RequestContext[ClientSession, None]) -> ListRootsResult: """List roots callback that will be called by the MCP library.""" - if server_config := get_server_config(ctx): + if server_config := get_server_config(context.session): if server_config.roots: roots = [ Root( @@ -99,8 +100,8 @@ def __init__(self, read_stream, write_stream, read_timeout=None, **kwargs) -> No # Track the effective elicitation mode for diagnostics self.effective_elicitation_mode: str | None = "none" - version = version("fast-agent-mcp") or "dev" - fast_agent: Implementation = Implementation(name="fast-agent-mcp", version=version) + fast_agent_version = version("fast-agent-mcp") or "dev" + fast_agent: Implementation = Implementation(name="fast-agent-mcp", version=fast_agent_version) if self.server_config and self.server_config.implementation: fast_agent = self.server_config.implementation @@ -202,7 +203,7 @@ def _should_enable_auto_sampling(self) -> bool: async def send_request( self, - request: SendRequestT, + request: ClientRequest, result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, metadata: MessageMetadata | None = None, @@ -318,14 +319,14 @@ async def _handle_tool_list_change_callback(self, server_name: str) -> None: except Exception as e: logger.error(f"Error in tool list changed callback: {e}") - # TODO -- decide whether to make this override type safe or not (modify SDK) async def call_tool( self, name: str, - arguments: dict | None = None, - _meta: dict | None = None, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: timedelta | None = None, progress_callback: ProgressFnT | None = None, -
**kwargs, + *, + meta: dict[str, Any] | None = None, ) -> CallToolResult: """Call a tool with optional metadata and progress callback support. @@ -336,31 +337,23 @@ async def call_tool( # Always create request ourselves to ensure we go through our send_request override # This is critical for session terminated detection to work - params = CallToolRequestParams(name=name, arguments=arguments) + _meta: RequestParams.Meta | None = None + if meta is not None: + _meta = RequestParams.Meta(**meta) - if _meta: - # Safe merge - preserve existing meta fields like progressToken - existing_meta = kwargs.get("meta") - if existing_meta: - meta_dict = ( - existing_meta.model_dump() if hasattr(existing_meta, "model_dump") else {} - ) - meta_dict.update(_meta) - meta_obj = RequestParams.Meta(**meta_dict) - else: - meta_obj = RequestParams.Meta(**_meta) - - params_dict = params.model_dump(by_alias=True) - params_dict["_meta"] = meta_obj.model_dump() - params = CallToolRequestParams.model_validate(params_dict) + # ty doesn't recognize _meta from pydantic alias - this matches SDK pattern + params = CallToolRequestParams(name=name, arguments=arguments, _meta=_meta) # ty: ignore[unknown-argument] request = CallToolRequest(method="tools/call", params=params) return await self.send_request( - request, CallToolResult, progress_callback=progress_callback + ClientRequest(request), + CallToolResult, + request_read_timeout_seconds=read_timeout_seconds, + progress_callback=progress_callback, ) async def read_resource( - self, uri: str, _meta: dict | None = None, **kwargs + self, uri: AnyUrl | str, _meta: dict | None = None, **kwargs ) -> ReadResourceResult: """Read a resource with optional metadata support. @@ -369,8 +362,11 @@ async def read_resource( """ from mcp.types import RequestParams + # Convert str to AnyUrl if needed + uri_obj: AnyUrl = uri if isinstance(uri, AnyUrl) else AnyUrl(uri) + # Always create request ourselves to ensure we go through our send_request override - params = ReadResourceRequestParams(uri=uri) + params = ReadResourceRequestParams(uri=uri_obj) if _meta: # Safe merge - preserve existing meta fields like progressToken @@ -383,10 +379,10 @@ async def read_resource( meta_obj = RequestParams.Meta(**meta_dict) else: meta_obj = RequestParams.Meta(**_meta) - params = ReadResourceRequestParams(uri=uri, meta=meta_obj) + params = ReadResourceRequestParams(uri=uri_obj, meta=meta_obj) request = ReadResourceRequest(method="resources/read", params=params) - return await self.send_request(request, ReadResourceResult) + return await self.send_request(ClientRequest(request), ReadResourceResult) async def get_prompt( self, name: str, arguments: dict | None = None, _meta: dict | None = None, **kwargs @@ -415,4 +411,4 @@ async def get_prompt( params = GetPromptRequestParams(name=name, arguments=arguments, meta=meta_obj) request = GetPromptRequest(method="prompts/get", params=params) - return await self.send_request(request, GetPromptResult) + return await self.send_request(ClientRequest(request), GetPromptResult) diff --git a/src/fast_agent/mcp/sampling.py b/src/fast_agent/mcp/sampling.py index b158202ca..3cf7cc068 100644 --- a/src/fast_agent/mcp/sampling.py +++ b/src/fast_agent/mcp/sampling.py @@ -2,9 +2,10 @@ This simplified implementation directly converts between MCP types and PromptMessageExtended. 
""" -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from mcp import ClientSession +from mcp.shared.context import RequestContext from mcp.types import CreateMessageRequestParams, CreateMessageResult, TextContent from fast_agent.agents.agent_types import AgentConfig @@ -60,7 +61,9 @@ def create_sampling_llm( return llm -async def sample(mcp_ctx: ClientSession, params: CreateMessageRequestParams) -> CreateMessageResult: +async def sample( + context: RequestContext[ClientSession, Any], params: CreateMessageRequestParams +) -> CreateMessageResult: """ Handle sampling requests from the MCP protocol using SamplingConverter. @@ -71,16 +74,14 @@ async def sample(mcp_ctx: ClientSession, params: CreateMessageRequestParams) -> 4. Returns the result as a CreateMessageResult Args: - mcp_ctx: The MCP ClientSession + context: The MCP RequestContext containing the ClientSession params: The sampling request parameters Returns: A CreateMessageResult containing the LLM's response """ # Get server name for notification tracking - server_name = "unknown" - if hasattr(mcp_ctx, "session") and hasattr(mcp_ctx.session, "session_server_name"): - server_name = mcp_ctx.session.session_server_name or "unknown" + server_name: str = getattr(context.session, "session_server_name", None) or "unknown" # Start tracking sampling operation try: @@ -94,7 +95,7 @@ async def sample(mcp_ctx: ClientSession, params: CreateMessageRequestParams) -> api_key: str | None = None try: # Extract model from server config using type-safe helper - server_config = get_server_config(mcp_ctx) + server_config = get_server_config(context) # First priority: explicitly configured sampling model if server_config and server_config.sampling: @@ -119,14 +120,12 @@ async def sample(mcp_ctx: ClientSession, params: CreateMessageRequestParams) -> from fast_agent.mcp.mcp_agent_client_session import MCPAgentClientSession # Try agent's model first (from the session) - if hasattr(mcp_ctx, "session") and isinstance( - mcp_ctx.session, MCPAgentClientSession - ): - if mcp_ctx.session.agent_model: - model = mcp_ctx.session.agent_model + if isinstance(context.session, MCPAgentClientSession): + if context.session.agent_model: + model = context.session.agent_model logger.debug(f"Using agent's model for sampling: {model}") - if mcp_ctx.session.api_key: - api_key = mcp_ctx.session.api_key + if context.session.api_key: + api_key = context.session.api_key logger.debug(f"Using agent's API KEY for sampling: {api_key}") # Fall back to system default model From 6dac87273bb4f58c86d4441a6b672f92a2bcf63b Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Tue, 23 Dec 2025 22:26:43 +0000 Subject: [PATCH 11/15] type checks --- src/fast_agent/agents/mcp_agent.py | 2 +- .../agents/workflow/agents_as_tools_agent.py | 12 +-- src/fast_agent/cli/commands/auth.py | 5 +- src/fast_agent/core/direct_decorators.py | 6 +- src/fast_agent/core/direct_factory.py | 4 +- src/fast_agent/core/executor/executor.py | 43 +++++---- .../core/executor/workflow_signal.py | 8 +- src/fast_agent/core/instruction.py | 9 +- src/fast_agent/core/logging/logger.py | 17 ++++ .../human_input/elicitation_handler.py | 5 +- src/fast_agent/llm/fastagent_llm.py | 12 ++- src/fast_agent/llm/internal/playback.py | 2 +- .../llm/provider/bedrock/llm_bedrock.py | 8 +- .../bedrock/multipart_converter_bedrock.py | 9 +- .../llm/provider/google/google_converter.py | 16 ++-- .../llm/provider/google/llm_google_native.py | 23 +++-- .../llm/provider/openai/llm_azure.py | 
2 +- .../llm/provider/openai/llm_generic.py | 4 +- .../llm/provider/openai/llm_groq.py | 2 + .../llm/provider/openai/llm_openai.py | 72 ++++++++------- .../provider/openai/llm_tensorzero_openai.py | 35 +++++--- src/fast_agent/llm/provider/openai/llm_xai.py | 4 +- .../openai/multipart_converter_openai.py | 88 ++++++++++++------- .../llm/provider/openai/openai_multipart.py | 16 ++-- .../llm/provider/openai/openai_utils.py | 19 +++- src/fast_agent/llm/sampling_converter.py | 15 +++- src/fast_agent/mcp/logger_textio.py | 2 +- .../mcp/streamable_http_tracking.py | 2 +- src/fast_agent/skills/registry.py | 7 +- 29 files changed, 287 insertions(+), 162 deletions(-) diff --git a/src/fast_agent/agents/mcp_agent.py b/src/fast_agent/agents/mcp_agent.py index f61bf5265..ac6f72aff 100644 --- a/src/fast_agent/agents/mcp_agent.py +++ b/src/fast_agent/agents/mcp_agent.py @@ -1315,7 +1315,7 @@ def _shell_server_label(self) -> str | None: runtime_name = runtime_info.get("name") return runtime_name or "shell" - async def _parse_resource_name(self, name: str, resource_type: str) -> tuple[str, str]: + async def _parse_resource_name(self, name: str, resource_type: str) -> tuple[str | None, str]: """Delegate resource name parsing to the aggregator.""" return await self._aggregator._parse_resource_name(name, resource_type) diff --git a/src/fast_agent/agents/workflow/agents_as_tools_agent.py b/src/fast_agent/agents/workflow/agents_as_tools_agent.py index 444a08fe3..6b3c2b538 100644 --- a/src/fast_agent/agents/workflow/agents_as_tools_agent.py +++ b/src/fast_agent/agents/workflow/agents_as_tools_agent.py @@ -191,7 +191,7 @@ async def coordinator(): pass from copy import copy from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from mcp import ListToolsResult, Tool from mcp.types import CallToolResult @@ -638,10 +638,10 @@ async def _run_child_tools( call_descriptors: list[dict[str, Any]] = [] descriptor_by_id: dict[str, dict[str, Any]] = {} - tasks: list[asyncio.Task] = [] + tasks: list[asyncio.Task[CallToolResult]] = [] id_list: list[str] = [] - for correlation_id, tool_request in request.tool_calls.items(): + for correlation_id, tool_request in (request.tool_calls or {}).items(): if correlation_id not in target_ids: continue @@ -816,9 +816,11 @@ async def call_with_instance_name( descriptor_by_id[correlation_id]["status"] = "error" descriptor_by_id[correlation_id]["error_message"] = msg else: - tool_results[correlation_id] = result + # After exception check, result is CallToolResult + tool_result = cast("CallToolResult", result) + tool_results[correlation_id] = tool_result descriptor_by_id[correlation_id]["status"] = ( - "error" if result.isError else "done" + "error" if tool_result.isError else "done" ) ordered_records: list[dict[str, Any]] = [] diff --git a/src/fast_agent/cli/commands/auth.py b/src/fast_agent/cli/commands/auth.py index 00c5b4bf2..266b3fcba 100644 --- a/src/fast_agent/cli/commands/auth.py +++ b/src/fast_agent/cli/commands/auth.py @@ -333,9 +333,12 @@ def login( typer.echo("--transport must be 'http' or 'sse'") raise typer.Exit(1) endpoint = base + ("/mcp" if resolved_transport == "http" else "/sse") + # Cast transport after validation + from typing import Literal, cast + transport_type = cast("Literal['stdio', 'sse', 'http']", resolved_transport) cfg = MCPServerSettings( name=base, - transport=resolved_transport, + transport=transport_type, url=endpoint, auth=MCPServerAuthSettings(), ) diff --git 
a/src/fast_agent/core/direct_decorators.py b/src/fast_agent/core/direct_decorators.py index ccb4e9cd8..8c4bf6d10 100644 --- a/src/fast_agent/core/direct_decorators.py +++ b/src/fast_agent/core/direct_decorators.py @@ -223,9 +223,9 @@ def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutin name=name, instruction=instruction, servers=servers, - tools=tools, - resources=resources, - prompts=prompts, + tools=tools or {}, + resources=resources or {}, + prompts=prompts or {}, skills=skills, model=model, use_history=use_history, diff --git a/src/fast_agent/core/direct_factory.py b/src/fast_agent/core/direct_factory.py index 8c4d6b131..bd0175023 100644 --- a/src/fast_agent/core/direct_factory.py +++ b/src/fast_agent/core/direct_factory.py @@ -5,7 +5,7 @@ import os from functools import partial -from typing import Any, Protocol, TypeVar +from typing import Any, Protocol, TypeVar, cast from fast_agent.agents import McpAgent from fast_agent.agents.agent_types import AgentConfig, AgentType @@ -230,7 +230,7 @@ async def create_agents_by_type( agent = AgentsAsToolsAgent( config=config, context=app_instance.context, - agents=child_agents, # expose children as tools + agents=cast("list[LlmAgent]", child_agents), # expose children as tools options=options, ) diff --git a/src/fast_agent/core/executor/executor.py b/src/fast_agent/core/executor/executor.py index 1ae5d1802..3dc6f03b9 100644 --- a/src/fast_agent/core/executor/executor.py +++ b/src/fast_agent/core/executor/executor.py @@ -10,9 +10,9 @@ AsyncIterator, Callable, Coroutine, - Type, TypeVar, Union, + cast, ) from pydantic import BaseModel, ConfigDict @@ -88,7 +88,7 @@ async def execute( @abstractmethod async def execute_streaming( self, - *tasks: list[Callable[..., R] | Coroutine[Any, Any, R]], + *tasks: Callable[..., R] | Coroutine[Any, Any, R], **kwargs: Any, ) -> AsyncIterator[R | BaseException]: """Execute tasks and yield results as they complete""" @@ -102,7 +102,7 @@ async def map( """ Run `func(item)` for each item in `inputs` with concurrency limit. """ - results: list[R, BaseException] = [] + results: list[R | BaseException] = [] async def run(item): if self.config.max_concurrent_activities: @@ -134,16 +134,18 @@ async def validate_task(self, task: Callable[..., R] | Coroutine[Any, Any, R]) - async def signal( self, signal_name: str, - payload: SignalValueT = None, + payload: SignalValueT | None = None, signal_description: str | None = None, ) -> None: """ Emit a signal. """ - signal = Signal[SignalValueT]( + if self.signal_bus is None: + raise RuntimeError("No signal bus configured") + sig: Signal[SignalValueT] = Signal( name=signal_name, payload=payload, description=signal_description ) - await self.signal_bus.signal(signal) + await self.signal_bus.signal(sig) async def wait_for_signal( self, @@ -152,12 +154,14 @@ async def wait_for_signal( workflow_id: str | None = None, signal_description: str | None = None, timeout_seconds: int | None = None, - signal_type: Type[SignalValueT] = str, - ) -> SignalValueT: + signal_type: type[Any] | None = None, + ) -> Any: """ Wait until a signal with signal_name is emitted (or timeout). Return the signal's payload when triggered, or raise on timeout. 
""" + if self.signal_bus is None: + raise RuntimeError("No signal bus configured") # Notify any callbacks that the workflow is about to be paused waiting for a signal if self.context.signal_notification: @@ -168,14 +172,14 @@ async def wait_for_signal( metadata={ "description": signal_description, "timeout_seconds": timeout_seconds, - "signal_type": signal_type, + "signal_type": signal_type or str, }, ) - signal = Signal[signal_type]( + sig: Signal[Any] = Signal( name=signal_name, description=signal_description, workflow_id=workflow_id ) - return await self.signal_bus.wait_for_signal(signal) + return await self.signal_bus.wait_for_signal(sig) class AsyncioExecutor(Executor): @@ -196,10 +200,13 @@ def __init__( async def _execute_task( self, task: Callable[..., R] | Coroutine[Any, Any, R], **kwargs: Any ) -> R | BaseException: - async def run_task(task: Callable[..., R] | Coroutine[Any, Any, R]) -> R: + async def run_task( + task: Callable[..., R] | Coroutine[Any, Any, R], + ) -> R | BaseException: try: if asyncio.iscoroutine(task): - return await task + # iscoroutine doesn't narrow types, so cast the result + return cast("R", await task) elif asyncio.iscoroutinefunction(task): return await task(**kwargs) else: @@ -237,9 +244,9 @@ async def execute( return_exceptions=True, ) - async def execute_streaming( + async def execute_streaming( # ty: ignore[invalid-method-override] self, - *tasks: list[Callable[..., R] | Coroutine[Any, Any, R]], + *tasks: Callable[..., R] | Coroutine[Any, Any, R], **kwargs: Any, ) -> AsyncIterator[R | BaseException]: # TODO: saqadri - validate if async with self.execution_context() is needed here @@ -256,7 +263,7 @@ async def execute_streaming( async def signal( self, signal_name: str, - payload: SignalValueT = None, + payload: SignalValueT | None = None, signal_description: str | None = None, ) -> None: await super().signal(signal_name, payload, signal_description) @@ -268,8 +275,8 @@ async def wait_for_signal( workflow_id: str | None = None, signal_description: str | None = None, timeout_seconds: int | None = None, - signal_type: Type[SignalValueT] = str, - ) -> SignalValueT: + signal_type: type[Any] | None = None, + ) -> Any: return await super().wait_for_signal( signal_name, request_id, diff --git a/src/fast_agent/core/executor/workflow_signal.py b/src/fast_agent/core/executor/workflow_signal.py index c7fb4cf14..c57dfa959 100644 --- a/src/fast_agent/core/executor/workflow_signal.py +++ b/src/fast_agent/core/executor/workflow_signal.py @@ -234,15 +234,17 @@ async def wait_for_signal(self, signal, timeout_seconds: int | None = None) -> S if not self._pending_signals[signal.name]: del self._pending_signals[signal.name] - def on_signal(self, signal_name): - def decorator(func): + def on_signal(self, signal_name: str) -> Callable: + def decorator(func: Callable) -> Callable: + unique_name = f"{signal_name}_{uuid.uuid4()}" + async def wrapped(value: SignalValueT) -> None: if asyncio.iscoroutinefunction(func): await func(value) else: func(value) - self._handlers.setdefault(signal_name, []).append(wrapped) + self._handlers.setdefault(signal_name, []).append((unique_name, wrapped)) return wrapped return decorator diff --git a/src/fast_agent/core/instruction.py b/src/fast_agent/core/instruction.py index 17b5b0f13..af266509a 100644 --- a/src/fast_agent/core/instruction.py +++ b/src/fast_agent/core/instruction.py @@ -39,8 +39,9 @@ logger = get_logger(__name__) -# Type alias for async resolvers -Resolver = Callable[[], Awaitable[str]] +# Type aliases +Resolver = Callable[[], 
Awaitable[str]] # Type alias for async resolvers +Set = set # Preserve built-in set type before method shadowing def _get_current_date() -> str: @@ -287,7 +288,7 @@ def replace_file(match: re.Match) -> str: # Utilities # ───────────────────────────────────────────────────────────────────────── - def get_placeholders(self) -> set[str]: + def get_placeholders(self) -> Set[str]: """ Extract all placeholder names from the template. @@ -298,7 +299,7 @@ def get_placeholders(self) -> set[str]: pattern = re.compile(r"\{\{(?!url:|file:|file_silent:)([^}]+)\}\}") return set(pattern.findall(self._template)) - def get_unresolved_placeholders(self) -> set[str]: + def get_unresolved_placeholders(self) -> Set[str]: """ Get placeholders that don't have a source registered. diff --git a/src/fast_agent/core/logging/logger.py b/src/fast_agent/core/logging/logger.py index c3a3d7853..73be23300 100644 --- a/src/fast_agent/core/logging/logger.py +++ b/src/fast_agent/core/logging/logger.py @@ -113,6 +113,23 @@ def error( """Log an error message.""" self.event("error", name, message, context, data) + def exception( + self, + message: str, + name: str | None = None, + context: EventContext | None = None, + **data, + ) -> None: + """Log an error message with exception info.""" + import sys + import traceback + + exc_info = sys.exc_info() + if exc_info[0] is not None: + tb_str = "".join(traceback.format_exception(*exc_info)) + data["exception"] = tb_str + self.event("error", name, message, context, data) + def progress( self, message: str, diff --git a/src/fast_agent/human_input/elicitation_handler.py b/src/fast_agent/human_input/elicitation_handler.py index 452b9568c..f677a4791 100644 --- a/src/fast_agent/human_input/elicitation_handler.py +++ b/src/fast_agent/human_input/elicitation_handler.py @@ -39,9 +39,10 @@ async def elicitation_input_callback( try: # Check if elicitation is disabled for this server + request_id = request.request_id or "" if elicitation_state.is_disabled(effective_server_name): return HumanInputResponse( - request_id=request.request_id, + request_id=request_id, response="__CANCELLED__", metadata={"auto_cancelled": True, "reason": "Server elicitation disabled by user"}, ) @@ -92,7 +93,7 @@ async def elicitation_input_callback( response = "__CANCELLED__" return HumanInputResponse( - request_id=request.request_id, + request_id=request_id, response=response.strip() if isinstance(response, str) else response, metadata={"has_schema": schema is not None}, ) diff --git a/src/fast_agent/llm/fastagent_llm.py b/src/fast_agent/llm/fastagent_llm.py index f8cd589cc..56f284a37 100644 --- a/src/fast_agent/llm/fastagent_llm.py +++ b/src/fast_agent/llm/fastagent_llm.py @@ -117,7 +117,7 @@ def __init__( context: Union["Context", None] = None, model: str | None = None, api_key: str | None = None, - **kwargs: dict[str, Any], + **kwargs: Any, ) -> None: """ @@ -172,12 +172,12 @@ def __init__( self.retry_count = self._resolve_retry_count() self.retry_backoff_seconds: float = 10.0 - def _initialize_default_params(self, kwargs: dict) -> RequestParams: + def _initialize_default_params(self, kwargs: dict[str, Any]) -> RequestParams: """Initialize default parameters for the LLM. 
Should be overridden by provider implementations to set provider-specific defaults.""" # Get model-aware default max tokens model = kwargs.get("model") - max_tokens = ModelDatabase.get_default_max_tokens(model) + max_tokens = ModelDatabase.get_default_max_tokens(model) if model else 16384 return RequestParams( model=model, @@ -511,7 +511,7 @@ def _structured_from_multipart( text = self._prepare_structured_text(text) json_data = from_json(text, allow_partial=True) validated_model = model.model_validate(json_data) - return cast("ModelT", validated_model), message + return validated_model, message except ValueError as e: logger = get_logger(__name__) logger.warning(f"Failed to parse structured response: {str(e)}") @@ -892,6 +892,10 @@ def _api_key(self): def usage_accumulator(self): return self._usage_accumulator + @usage_accumulator.setter + def usage_accumulator(self, value): + self._usage_accumulator = value + def get_usage_summary(self) -> dict: """ Get a summary of usage statistics for this LLM instance. diff --git a/src/fast_agent/llm/internal/playback.py b/src/fast_agent/llm/internal/playback.py index 356bb0b38..510bfc46d 100644 --- a/src/fast_agent/llm/internal/playback.py +++ b/src/fast_agent/llm/internal/playback.py @@ -54,7 +54,7 @@ def _get_next_assistant_message(self) -> PromptMessageExtended: f"MESSAGES EXHAUSTED (list size {len(self._messages)}) ({self._overage} overage)" ) - async def generate( + async def generate( # type: ignore[override] self, messages: Union[ str, diff --git a/src/fast_agent/llm/provider/bedrock/llm_bedrock.py b/src/fast_agent/llm/provider/bedrock/llm_bedrock.py index ab4b5749d..64ebee73e 100644 --- a/src/fast_agent/llm/provider/bedrock/llm_bedrock.py +++ b/src/fast_agent/llm/provider/bedrock/llm_bedrock.py @@ -43,10 +43,10 @@ NoCredentialsError, ) except ImportError: - boto3 = None - BotoCoreError = Exception - ClientError = Exception - NoCredentialsError = Exception + boto3 = None # type: ignore[assignment] + BotoCoreError = Exception # type: ignore[assignment, misc] + ClientError = Exception # type: ignore[assignment, misc] + NoCredentialsError = Exception # type: ignore[assignment, misc] DEFAULT_BEDROCK_MODEL = "amazon.nova-lite-v1:0" diff --git a/src/fast_agent/llm/provider/bedrock/multipart_converter_bedrock.py b/src/fast_agent/llm/provider/bedrock/multipart_converter_bedrock.py index 24c7b6179..e2df37f37 100644 --- a/src/fast_agent/llm/provider/bedrock/multipart_converter_bedrock.py +++ b/src/fast_agent/llm/provider/bedrock/multipart_converter_bedrock.py @@ -24,7 +24,8 @@ def convert_to_bedrock(multipart_msg: PromptMessageExtended) -> BedrockMessagePa A Bedrock API message parameter dictionary """ # Simple conversion without needing BedrockLLM instance - bedrock_msg = {"role": multipart_msg.role, "content": []} + content_list: list[dict[str, Any]] = [] + bedrock_msg: BedrockMessageParam = {"role": multipart_msg.role, "content": content_list} # Handle tool results first (if present) if multipart_msg.tool_results: @@ -53,7 +54,7 @@ def convert_to_bedrock(multipart_msg: PromptMessageExtended) -> BedrockMessagePa if tool_result_parts: full_result_text = f"Tool Results:\n{', '.join(tool_result_parts)}" - bedrock_msg["content"].append({"type": "text", "text": full_result_text}) + content_list.append({"type": "text", "text": full_result_text}) else: # For Nova/Anthropic models: use structured tool_result format for tool_id, tool_result in multipart_msg.tool_results.items(): @@ -66,7 +67,7 @@ def convert_to_bedrock(multipart_msg: PromptMessageExtended) -> 
BedrockMessagePa if not result_content_blocks: result_content_blocks.append({"text": "[No content in tool result]"}) - bedrock_msg["content"].append( + content_list.append( { "type": "tool_result", "tool_use_id": tool_id, @@ -79,6 +80,6 @@ def convert_to_bedrock(multipart_msg: PromptMessageExtended) -> BedrockMessagePa from mcp.types import TextContent for content_item in multipart_msg.content: if isinstance(content_item, TextContent): - bedrock_msg["content"].append({"type": "text", "text": content_item.text}) + content_list.append({"type": "text", "text": content_item.text}) return bedrock_msg diff --git a/src/fast_agent/llm/provider/google/google_converter.py b/src/fast_agent/llm/provider/google/google_converter.py index 292d1e018..c5da75cc0 100644 --- a/src/fast_agent/llm/provider/google/google_converter.py +++ b/src/fast_agent/llm/provider/google/google_converter.py @@ -198,8 +198,10 @@ def convert_to_google_content( ) else: # Check if the resource itself has text content - # Use get_text helper to extract text from various content types - resource_text = get_text(part_content.resource) + # Try to get text from TextResourceContents directly + resource_text: str | None = None + if isinstance(part_content.resource, TextResourceContents): + resource_text = part_content.resource.text if resource_text is not None: parts.append(types.Part.from_text(text=resource_text)) @@ -273,7 +275,7 @@ def convert_from_google_content( elif part.function_call: fast_agent_parts.append( CallToolRequestParams( - name=part.function_call.name, + name=part.function_call.name or "unknown_function", arguments=part.function_call.args, ) ) @@ -288,7 +290,7 @@ def convert_from_google_function_call( return CallToolRequest( method="tools/call", params=CallToolRequestParams( - name=function_call.name, + name=function_call.name or "unknown_function", arguments=function_call.args, ), ) @@ -337,8 +339,10 @@ def convert_function_results_to_google( textual_outputs.append(f"[Error processing PDF from tool result: {e}]") else: # Check if the resource itself has text content - # Use get_text helper to extract text from various content types - resource_text = get_text(item.resource) + # Try to get text from TextResourceContents directly + resource_text: str | None = None + if isinstance(item.resource, TextResourceContents): + resource_text = item.resource.text if resource_text is not None: textual_outputs.append(resource_text) diff --git a/src/fast_agent/llm/provider/google/llm_google_native.py b/src/fast_agent/llm/provider/google/llm_google_native.py index 2a83bddb0..4acb07115 100644 --- a/src/fast_agent/llm/provider/google/llm_google_native.py +++ b/src/fast_agent/llm/provider/google/llm_google_native.py @@ -390,14 +390,16 @@ async def _google_completion( generate_content_config.response_schema = response_schema elif available_tools: # Tool calling enabled only when not doing structured output - generate_content_config.tools = available_tools + generate_content_config.tools = available_tools # type: ignore[assignment] generate_content_config.tool_config = types.ToolConfig( - function_calling_config=types.FunctionCallingConfig(mode="AUTO") + function_calling_config=types.FunctionCallingConfig( + mode=types.FunctionCallingConfigMode.AUTO + ) ) # 3. 
Call the google.genai API client = self._initialize_google_client() - model_name = self._resolve_model_name(request_params.model) + model_name = self._resolve_model_name(request_params.model or DEFAULT_GOOGLE_MODEL) try: # Use the async client api_response = None @@ -454,17 +456,24 @@ async def _google_completion( if not api_response.candidates: # No response from the model, we're done self.logger.debug("No candidates returned.") + return Prompt.assistant(stop_reason=LlmStopReason.END_TURN) candidate = api_response.candidates[0] # Process the first candidate # Convert the model's response content to fast-agent types - model_response_content_parts = self._converter.convert_from_google_content( - candidate.content - ) + # Handle case where candidate.content might be None + candidate_content = candidate.content + if candidate_content is None: + model_response_content_parts: list[ContentBlock | CallToolRequestParams] = [] + else: + model_response_content_parts = self._converter.convert_from_google_content( + candidate_content + ) stop_reason = LlmStopReason.END_TURN tool_calls: dict[str, CallToolRequest] | None = None # Add model's response to the working conversation history for this turn - conversation_history.append(candidate.content) + if candidate_content is not None: + conversation_history.append(candidate_content) # Extract and process text content and tool calls assistant_message_parts = [] diff --git a/src/fast_agent/llm/provider/openai/llm_azure.py b/src/fast_agent/llm/provider/openai/llm_azure.py index 0ccf05c94..c123a4e2e 100644 --- a/src/fast_agent/llm/provider/openai/llm_azure.py +++ b/src/fast_agent/llm/provider/openai/llm_azure.py @@ -5,7 +5,7 @@ from fast_agent.llm.provider_types import Provider try: - from azure.identity import DefaultAzureCredential + from azure.identity import DefaultAzureCredential # ty: ignore[unresolved-import] except ImportError: DefaultAzureCredential = None diff --git a/src/fast_agent/llm/provider/openai/llm_generic.py b/src/fast_agent/llm/provider/openai/llm_generic.py index 95703adba..a2c4df623 100644 --- a/src/fast_agent/llm/provider/openai/llm_generic.py +++ b/src/fast_agent/llm/provider/openai/llm_generic.py @@ -26,8 +26,8 @@ def _initialize_default_params(self, kwargs: dict) -> RequestParams: use_history=True, ) - def _base_url(self) -> str: - base_url = os.getenv("GENERIC_BASE_URL", DEFAULT_OLLAMA_BASE_URL) + def _base_url(self) -> str | None: + base_url: str | None = os.getenv("GENERIC_BASE_URL", DEFAULT_OLLAMA_BASE_URL) if self.context.config and self.context.config.generic: base_url = self.context.config.generic.base_url diff --git a/src/fast_agent/llm/provider/openai/llm_groq.py b/src/fast_agent/llm/provider/openai/llm_groq.py index 9f43b06b6..0de8a7529 100644 --- a/src/fast_agent/llm/provider/openai/llm_groq.py +++ b/src/fast_agent/llm/provider/openai/llm_groq.py @@ -32,6 +32,8 @@ def _supports_structured_prompt(self) -> bool: llm_model = ( self.default_request_params.model if self.default_request_params else DEFAULT_GROQ_MODEL ) + if not llm_model: + return False json_mode: str | None = ModelDatabase.get_json_mode(llm_model) return json_mode == "object" diff --git a/src/fast_agent/llm/provider/openai/llm_openai.py b/src/fast_agent/llm/provider/openai/llm_openai.py index e84c78196..6b2d0b69b 100644 --- a/src/fast_agent/llm/provider/openai/llm_openai.py +++ b/src/fast_agent/llm/provider/openai/llm_openai.py @@ -3,7 +3,7 @@ import os from datetime import datetime from pathlib import Path -from typing import Any +from typing import Any, cast 
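A note on a pattern that recurs in the llm_openai.py hunks below: provider extension fields such as reasoning_content are not declared on OpenAI's message TypedDicts, so the code widens the value with cast("dict[str, Any]", msg) before writing the key. A minimal sketch of that pattern with a stand-in TypedDict, not the openai types:

from typing import Any, TypedDict, cast

class StubAssistantMessage(TypedDict):
    role: str
    content: str

def attach_reasoning(msg: StubAssistantMessage, reasoning: str) -> None:
    # The TypedDict does not declare this key, so view the same runtime
    # dict as dict[str, Any] for the write; no copy is made.
    cast("dict[str, Any]", msg)["reasoning_content"] = reasoning

msg: StubAssistantMessage = {"role": "assistant", "content": "hi"}
attach_reasoning(msg, "chain-of-thought text")
assert "reasoning_content" in msg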
from mcp import Tool from mcp.types import ( @@ -21,7 +21,9 @@ ChatCompletionMessageParam, ChatCompletionSystemMessageParam, ChatCompletionToolParam, + ChatCompletionUserMessageParam, ) +from openai.types.chat.chat_completion_message_tool_call import Function from pydantic_core import from_json from fast_agent.constants import FAST_AGENT_ERROR_CHANNEL, REASONING @@ -31,7 +33,7 @@ from fast_agent.event_progress import ProgressAction from fast_agent.llm.fastagent_llm import FastAgentLLM, RequestParams from fast_agent.llm.model_database import ModelDatabase -from fast_agent.llm.provider.openai.multipart_converter_openai import OpenAIConverter, OpenAIMessage +from fast_agent.llm.provider.openai.multipart_converter_openai import OpenAIConverter from fast_agent.llm.provider_types import Provider from fast_agent.llm.stream_types import StreamChunk from fast_agent.llm.usage_tracking import TurnUsage @@ -121,7 +123,7 @@ def __init__(self, provider: Provider = Provider.OPENAI, **kwargs) -> None: # Determine reasoning mode for the selected model chosen_model = self.default_request_params.model if self.default_request_params else None - self._reasoning_mode = ModelDatabase.get_reasoning(chosen_model) + self._reasoning_mode = ModelDatabase.get_reasoning(chosen_model) if chosen_model else None self._reasoning = self._reasoning_mode == "openai" if self._reasoning_mode: self.logger.info( @@ -140,8 +142,10 @@ def _initialize_default_params(self, kwargs: dict) -> RequestParams: return base_params - def _base_url(self) -> str: - return self.context.config.openai.base_url if self.context.config.openai else None + def _base_url(self) -> str | None: + if self.context.config and self.context.config.openai: + return self.context.config.openai.base_url + return None def _default_headers(self) -> dict[str, str] | None: """ @@ -761,10 +765,10 @@ async def _process_stream_manual( ChatCompletionMessageToolCall( id=tool_call_data["id"], type=tool_call_data["type"], - function={ - "name": tool_call_data["function"]["name"], - "arguments": tool_call_data["function"]["arguments"], - }, + function=Function( + name=tool_call_data["function"]["name"], + arguments=tool_call_data["function"]["arguments"], + ), ) ) @@ -819,7 +823,7 @@ async def _process_stream_manual( async def _openai_completion( self, - message: list[OpenAIMessage] | None, + message: list[ChatCompletionMessageParam] | None, request_params: RequestParams | None = None, tools: list[Tool] | None = None, ) -> PromptMessageExtended: @@ -842,19 +846,22 @@ async def _openai_completion( # The caller supplies the full history; convert it directly if message: - messages.extend(message) + messages.extend(cast("list[ChatCompletionMessageParam]", message)) - available_tools: list[ChatCompletionToolParam] | None = [ - { - "type": "function", - "function": { - "name": tool.name, - "description": tool.description if tool.description else "", - "parameters": self.adjust_schema(tool.inputSchema), - }, - } - for tool in tools or [] - ] + available_tools: list[ChatCompletionToolParam] | None = cast( + "list[ChatCompletionToolParam]", + [ + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description if tool.description else "", + "parameters": self.adjust_schema(tool.inputSchema), + }, + } + for tool in tools or [] + ], + ) if not available_tools: if self.provider in [Provider.DEEPSEEK, Provider.ALIYUN]: @@ -954,7 +961,7 @@ async def _openai_completion( message_dict["role"] = normalized_role or message_dict.get("role", "assistant") - 
messages.append(message_dict) + messages.append(cast("ChatCompletionMessageParam", message_dict)) stop_reason = LlmStopReason.END_TURN requested_tool_calls: dict[str, CallToolRequest] | None = None if await self._is_tool_stop_reason(choice.finish_reason) and message.tool_calls: @@ -1078,13 +1085,13 @@ async def _apply_prompt_provider_specific( # Convert the supplied history/messages directly converted_messages = self._convert_to_provider_format(multipart_messages) if not converted_messages: - converted_messages = [{"role": "user", "content": ""}] + converted_messages = [ChatCompletionUserMessageParam(role="user", content="")] return await self._openai_completion(converted_messages, req_params, tools) def _prepare_api_request( - self, messages, tools: list[ChatCompletionToolParam] | None, request_params: RequestParams - ) -> dict[str, str]: + self, messages: list[ChatCompletionMessageParam], tools: list[ChatCompletionToolParam] | None, request_params: RequestParams + ) -> dict[str, Any]: # Create base arguments dictionary # overriding model via request params not supported (intentional) @@ -1170,7 +1177,8 @@ def _convert_extended_messages_to_provider( List of OpenAI ChatCompletionMessageParam objects """ converted: list[ChatCompletionMessageParam] = [] - reasoning_mode = ModelDatabase.get_reasoning(self.default_request_params.model) + model = self.default_request_params.model + reasoning_mode = ModelDatabase.get_reasoning(model) if model else None for msg in messages: # convert_to_openai returns a list of messages @@ -1184,7 +1192,8 @@ def _convert_extended_messages_to_provider( if reasoning_texts: reasoning_content = "\n\n".join(reasoning_texts) for oai_msg in openai_msgs: - oai_msg["reasoning_content"] = reasoning_content + # reasoning_content is an OpenAI extension not in the TypedDict + cast("dict[str, Any]", oai_msg)["reasoning_content"] = reasoning_content # gpt-oss: per docs, reasoning should be dropped on subsequent sampling # UNLESS tool calling is involved. 
For tool calls, prefix the assistant @@ -1197,8 +1206,11 @@ def _convert_extended_messages_to_provider( if reasoning_texts: reasoning_text = "\n\n".join(reasoning_texts) for oai_msg in openai_msgs: - existing_content = oai_msg.get("content", "") or "" - oai_msg["content"] = reasoning_text + existing_content + # Cast to dict to allow string concatenation with content + oai_dict = cast("dict[str, Any]", oai_msg) + existing_content = oai_dict.get("content", "") or "" + if isinstance(existing_content, str): + oai_dict["content"] = reasoning_text + existing_content converted.extend(openai_msgs) diff --git a/src/fast_agent/llm/provider/openai/llm_tensorzero_openai.py b/src/fast_agent/llm/provider/openai/llm_tensorzero_openai.py index e9074f208..f02ffcf83 100644 --- a/src/fast_agent/llm/provider/openai/llm_tensorzero_openai.py +++ b/src/fast_agent/llm/provider/openai/llm_tensorzero_openai.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, cast from openai.types.chat import ChatCompletionMessageParam, ChatCompletionSystemMessageParam @@ -82,15 +82,21 @@ def _prepare_api_request( self.logger.debug(f"Injecting template variables: {request_params.template_vars}") system_message_found = False for i, msg in enumerate(messages): - if msg.get("role") == "system": + # Work with msg as a dict for type safety + msg_dict = cast("dict[str, Any]", msg) + if msg_dict.get("role") == "system": + content = msg_dict.get("content") # If content is a string, convert it to the TensorZero format - if isinstance(msg.get("content"), str): - messages[i] = ChatCompletionSystemMessageParam( - role="system", content=[request_params.template_vars] + if isinstance(content, str): + # TensorZero expects content as list with template vars + messages[i] = cast( + "ChatCompletionSystemMessageParam", + {"role": "system", "content": [request_params.template_vars]}, ) - elif isinstance(msg.get("content"), list): + elif isinstance(content, list) and len(content) > 0: # If content is already a list, merge the template vars - msg["content"][0].update(request_params.template_vars) + if isinstance(content[0], dict): + content[0].update(request_params.template_vars) system_message_found = True break @@ -98,13 +104,15 @@ def _prepare_api_request( # If no system message exists, create one messages.insert( 0, - ChatCompletionSystemMessageParam( - role="system", content=[request_params.template_vars] + cast( + "ChatCompletionSystemMessageParam", + {"role": "system", "content": [request_params.template_vars]}, ), ) # Add TensorZero-specific extra body parameters - extra_body = arguments.get("extra_body", {}) + extra_body_raw = arguments.get("extra_body", {}) + extra_body: dict[str, Any] = extra_body_raw if isinstance(extra_body_raw, dict) else {} if self._t0_episode_id: extra_body["tensorzero::episode_id"] = str(self._t0_episode_id) @@ -116,8 +124,11 @@ def _prepare_api_request( if t0_args: self.logger.debug(f"Merging tensorzero_arguments from metadata: {t0_args}") for msg in messages: - if msg.get("role") == "system" and isinstance(msg.get("content"), list): - msg["content"][0].update(t0_args) + msg_dict = cast("dict[str, Any]", msg) + content = msg_dict.get("content") + if msg_dict.get("role") == "system" and isinstance(content, list) and len(content) > 0: + if isinstance(content[0], dict): + content[0].update(t0_args) break if extra_body: diff --git a/src/fast_agent/llm/provider/openai/llm_xai.py b/src/fast_agent/llm/provider/openai/llm_xai.py index 126a6915d..0b6579500 100644 --- 
a/src/fast_agent/llm/provider/openai/llm_xai.py +++ b/src/fast_agent/llm/provider/openai/llm_xai.py @@ -25,8 +25,8 @@ def _initialize_default_params(self, kwargs: dict) -> RequestParams: return base_params - def _base_url(self) -> str: - base_url = os.getenv("XAI_BASE_URL", XAI_BASE_URL) + def _base_url(self) -> str | None: + base_url: str | None = os.getenv("XAI_BASE_URL", XAI_BASE_URL) if self.context.config and self.context.config.xai: base_url = self.context.config.xai.base_url diff --git a/src/fast_agent/llm/provider/openai/multipart_converter_openai.py b/src/fast_agent/llm/provider/openai/multipart_converter_openai.py index 9913b9853..e2367745c 100644 --- a/src/fast_agent/llm/provider/openai/multipart_converter_openai.py +++ b/src/fast_agent/llm/provider/openai/multipart_converter_openai.py @@ -8,7 +8,13 @@ PromptMessage, TextContent, ) -from openai.types.chat import ChatCompletionMessageParam +from openai.types.chat import ( + ChatCompletionAssistantMessageParam, + ChatCompletionMessageParam, + ChatCompletionMessageToolCallParam, + ChatCompletionToolMessageParam, + ChatCompletionUserMessageParam, +) from fast_agent.core.logging.logger import get_logger from fast_agent.mcp.helpers.content_helpers import ( @@ -37,6 +43,21 @@ class OpenAIConverter: """Converts MCP message types to OpenAI API format.""" + @staticmethod + def _make_message(role: str, content: Any) -> ChatCompletionMessageParam: + """Create a properly typed message based on role.""" + if role == "assistant": + return ChatCompletionAssistantMessageParam(role="assistant", content=content) + elif role == "user": + return ChatCompletionUserMessageParam(role="user", content=content) + elif role == "tool": + # Tool messages need tool_call_id, but this helper is for simple content messages + # Tool messages are handled separately in convert_tool_result_to_openai + return ChatCompletionUserMessageParam(role="user", content=content) + else: + # Default to user for unknown roles (system messages handled elsewhere) + return ChatCompletionUserMessageParam(role="user", content=content) + @staticmethod def _is_supported_image_type(mime_type: str) -> bool: """ @@ -55,7 +76,7 @@ def _is_supported_image_type(mime_type: str) -> bool: @staticmethod def convert_to_openai( multipart_msg: PromptMessageExtended, concatenate_text_blocks: bool = False - ) -> list[dict[str, Any]]: + ) -> list[ChatCompletionMessageParam]: """ Convert a PromptMessageExtended message to OpenAI API format. @@ -70,7 +91,7 @@ def convert_to_openai( # assistant message with tool_calls per OpenAI format to establish the # required call IDs before tool responses appear. 
if multipart_msg.role == "assistant" and multipart_msg.tool_calls: - tool_calls_list: list[dict[str, Any]] = [] + tool_calls_list: list[ChatCompletionMessageToolCallParam] = [] for tool_id, req in multipart_msg.tool_calls.items(): name = None arguments = {} @@ -83,17 +104,17 @@ def convert_to_openai( pass tool_calls_list.append( - { - "id": tool_id, - "type": "function", - "function": { + ChatCompletionMessageToolCallParam( + id=tool_id, + type="function", + function={ "name": name or "unknown_tool", "arguments": json.dumps(arguments), }, - } + ) ) - return [{"role": "assistant", "tool_calls": tool_calls_list, "content": ""}] + return [ChatCompletionAssistantMessageParam(role="assistant", tool_calls=tool_calls_list, content="")] # Handle tool_results first if present if multipart_msg.tool_results: @@ -122,7 +143,7 @@ def convert_to_openai( @staticmethod def _convert_content_to_message( content: list, role: str, concatenate_text_blocks: bool = False - ) -> dict[str, Any] | None: + ) -> ChatCompletionMessageParam | None: """ Convert content blocks to a single OpenAI message. @@ -136,11 +157,11 @@ def _convert_content_to_message( """ # Handle empty content if not content: - return {"role": role, "content": ""} + return OpenAIConverter._make_message(role, "") # single text block if 1 == len(content) and is_text_content(content[0]): - return {"role": role, "content": get_text(content[0])} + return OpenAIConverter._make_message(role, get_text(content[0])) # For user messages, convert each content block content_blocks: list[ContentBlock] = [] @@ -183,16 +204,15 @@ def _convert_content_to_message( content_blocks.append({"type": "text", "text": fallback_text}) if not content_blocks: - return {"role": role, "content": ""} + return OpenAIConverter._make_message(role, "") # If concatenate_text_blocks is True, combine adjacent text blocks if concatenate_text_blocks: content_blocks = OpenAIConverter._concatenate_text_blocks(content_blocks) - # Return user message with content blocks - result = {"role": role, "content": content_blocks} + # Return message with content blocks _logger.debug(f"Final message for role '{role}': {len(content_blocks)} content blocks") - return result + return OpenAIConverter._make_message(role, content_blocks) @staticmethod def _concatenate_text_blocks(blocks: list[ContentBlock]) -> list[ContentBlock]: @@ -252,7 +272,7 @@ def convert_prompt_message_to_openai( # Use the existing conversion method with the specified concatenation option # Since convert_to_openai now returns a list, we return the first element messages = OpenAIConverter.convert_to_openai(multipart, concatenate_text_blocks) - return messages[0] if messages else {"role": message.role, "content": ""} + return messages[0] if messages else OpenAIConverter._make_message(message.role, "") @staticmethod def _convert_image_content(content: ImageContent) -> ContentBlock: @@ -423,7 +443,7 @@ def convert_tool_result_to_openai( tool_result: CallToolResult, tool_call_id: str, concatenate_text_blocks: bool = False, - ) -> Union[dict[str, Any], tuple[dict[str, Any], list[dict[str, Any]]]]: + ) -> Union[ChatCompletionMessageParam, tuple[ChatCompletionMessageParam, list[ChatCompletionMessageParam]]]: """ Convert a CallToolResult to an OpenAI tool message. 
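The constructor-style rewrites in this converter are behavior-preserving for a reason worth noting: the ChatCompletion*Param types are TypedDicts, so "calling" them simply builds an ordinary dict. The wire format the OpenAI client sends is unchanged; only static checkers see the difference. A minimal sketch (placeholder IDs and content, not values from this patch):

    from openai.types.chat import ChatCompletionToolMessageParam

    msg = ChatCompletionToolMessageParam(
        role="tool",
        tool_call_id="call_123",  # placeholder ID
        content="[Tool completed successfully]",
    )
    # A TypedDict "constructor" erases to a plain dict at runtime.
    assert type(msg) is dict
    assert msg == {
        "role": "tool",
        "tool_call_id": "call_123",
        "content": "[Tool completed successfully]",
    }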
@@ -441,11 +461,11 @@ def convert_tool_result_to_openai( """ # Handle empty content case if not tool_result.content: - return { - "role": "tool", - "tool_call_id": tool_call_id, - "content": "[Tool completed successfully]", - } + return ChatCompletionToolMessageParam( + role="tool", + tool_call_id=tool_call_id, + content="[Tool completed successfully]", + ) # Separate text and non-text content text_content = [] @@ -477,11 +497,11 @@ def convert_tool_result_to_openai( tool_message_content = "[Tool completed successfully]" # Create the tool message with just the text - tool_message = { - "role": "tool", - "tool_call_id": tool_call_id, - "content": tool_message_content, - } + tool_message = ChatCompletionToolMessageParam( + role="tool", + tool_call_id=tool_call_id, + content=tool_message_content, + ) # If there's no non-text content, return just the tool message if not non_text_content: @@ -509,7 +529,7 @@ def convert_tool_result_to_openai( def convert_function_results_to_openai( results: dict[str, CallToolResult], concatenate_text_blocks: bool = False, - ) -> list[dict[str, Any]]: + ) -> list[ChatCompletionMessageParam]: """ Convert function call results to OpenAI messages. @@ -544,11 +564,11 @@ def convert_function_results_to_openai( except Exception as e: _logger.error(f"Failed to convert tool_call_id={tool_call_id}: {e}") # Create a basic tool response to prevent missing tool_call_id error - fallback_message = { - "role": "tool", - "tool_call_id": tool_call_id, - "content": f"[Conversion error: {str(e)}]", - } + fallback_message = ChatCompletionToolMessageParam( + role="tool", + tool_call_id=tool_call_id, + content=f"[Conversion error: {str(e)}]", + ) tool_messages.append(fallback_message) # CONDITIONAL REORDERING: Only reorder if there are user messages (mixed content) diff --git a/src/fast_agent/llm/provider/openai/openai_multipart.py b/src/fast_agent/llm/provider/openai/openai_multipart.py index 03293e439..c21806319 100644 --- a/src/fast_agent/llm/provider/openai/openai_multipart.py +++ b/src/fast_agent/llm/provider/openai/openai_multipart.py @@ -25,7 +25,8 @@ def openai_to_extended( message: Union[ ChatCompletionMessage, ChatCompletionMessageParam, - list[Union[ChatCompletionMessage, ChatCompletionMessageParam]], + dict[str, Any], + list[Union[ChatCompletionMessage, ChatCompletionMessageParam, dict[str, Any]]], ], ) -> Union[PromptMessageExtended, list[PromptMessageExtended]]: """ @@ -43,16 +44,21 @@ def openai_to_extended( def _openai_message_to_extended( - message: Union[ChatCompletionMessage, dict[str, Any]], + message: Union[ChatCompletionMessage, ChatCompletionMessageParam, dict[str, Any]], ) -> PromptMessageExtended: """Convert a single OpenAI message to PromptMessageExtended.""" # Get role and content from message - if isinstance(message, dict): + # ChatCompletionMessage is a class with attributes; MessageParam types are TypedDicts + if isinstance(message, ChatCompletionMessage): + role = message.role + content = message.content + elif isinstance(message, dict): role = message.get("role", "assistant") content = message.get("content", "") else: - role = message.role - content = message.content + # Fallback for any other object with role/content attributes + role = getattr(message, "role", "assistant") + content = getattr(message, "content", "") mcp_contents = [] diff --git a/src/fast_agent/llm/provider/openai/openai_utils.py b/src/fast_agent/llm/provider/openai/openai_utils.py index 257d4eda0..e081a1df5 100644 --- a/src/fast_agent/llm/provider/openai/openai_utils.py +++ 
b/src/fast_agent/llm/provider/openai/openai_utils.py @@ -5,7 +5,7 @@ delegating to the proper implementations in the providers/ directory. """ -from typing import Any, Union +from typing import Any, Union, cast from openai.types.chat import ( ChatCompletionMessage, @@ -32,7 +32,11 @@ def openai_message_to_prompt_message_multipart( Returns: A PromptMessageExtended representation """ - return openai_to_extended(message) + result = openai_to_extended(message) + # Single message input always returns single message + if isinstance(result, list): + return result[0] if result else PromptMessageExtended(role="assistant", content=[]) + return result def openai_message_param_to_prompt_message_multipart( @@ -47,7 +51,11 @@ def openai_message_param_to_prompt_message_multipart( Returns: A PromptMessageExtended representation """ - return openai_to_extended(message_param) + result = openai_to_extended(message_param) + # Single message input always returns single message + if isinstance(result, list): + return result[0] if result else PromptMessageExtended(role="assistant", content=[]) + return result def prompt_message_multipart_to_openai_message_param( @@ -64,4 +72,7 @@ def prompt_message_multipart_to_openai_message_param( """ # convert_to_openai now returns a list, return the first element for backward compatibility messages = OpenAIConverter.convert_to_openai(multipart) - return messages[0] if messages else {"role": multipart.role, "content": ""} + if messages: + return cast("ChatCompletionMessageParam", messages[0]) + # Fallback for empty conversion + return cast("ChatCompletionMessageParam", {"role": multipart.role, "content": ""}) diff --git a/src/fast_agent/llm/sampling_converter.py b/src/fast_agent/llm/sampling_converter.py index 3c0671b16..ccfc9cca6 100644 --- a/src/fast_agent/llm/sampling_converter.py +++ b/src/fast_agent/llm/sampling_converter.py @@ -3,10 +3,14 @@ This replaces the more complex provider-specific converters with direct conversions. """ - from mcp.types import ( + AudioContent, + ContentBlock, CreateMessageRequestParams, CreateMessageResult, + EmbeddedResource, + ImageContent, + ResourceLink, SamplingMessage, TextContent, ) @@ -38,7 +42,14 @@ def sampling_message_to_prompt_message( Returns: PromptMessageExtended suitable for use with LLMs """ - return PromptMessageExtended(role=message.role, content=[message.content]) + # Filter content to only include supported types + supported_content: list[ContentBlock] = [] + content = message.content + if isinstance( + content, (TextContent, ImageContent, AudioContent, ResourceLink, EmbeddedResource) + ): + supported_content.append(content) + return PromptMessageExtended(role=message.role, content=supported_content) @staticmethod def extract_request_params(params: CreateMessageRequestParams) -> RequestParams: diff --git a/src/fast_agent/mcp/logger_textio.py b/src/fast_agent/mcp/logger_textio.py index 6fad73b06..2410c849b 100644 --- a/src/fast_agent/mcp/logger_textio.py +++ b/src/fast_agent/mcp/logger_textio.py @@ -28,7 +28,7 @@ def __init__(self, server_name: str) -> None: # Keep track of complete and partial lines self._line_buffer = "" - def write(self, s: str) -> int: + def write(self, s: str) -> int: # type: ignore[override] """ Write data to our buffer and log any complete lines. 
""" diff --git a/src/fast_agent/mcp/streamable_http_tracking.py b/src/fast_agent/mcp/streamable_http_tracking.py index 8cbd0a210..b09cf2522 100644 --- a/src/fast_agent/mcp/streamable_http_tracking.py +++ b/src/fast_agent/mcp/streamable_http_tracking.py @@ -276,7 +276,7 @@ async def _handle_sse_response( # type: ignore[override] if last_event_id is not None: # pragma: no branch await self._handle_reconnection(ctx, "post-sse", last_event_id, retry_interval_ms) - async def _handle_reconnection( + async def _handle_reconnection( # type: ignore[override] self, ctx: RequestContext, channel: ChannelName, diff --git a/src/fast_agent/skills/registry.py b/src/fast_agent/skills/registry.py index 5719df114..5e9039fc3 100644 --- a/src/fast_agent/skills/registry.py +++ b/src/fast_agent/skills/registry.py @@ -190,8 +190,9 @@ def _parse_manifest(cls, manifest_path: Path) -> tuple[SkillManifest | None, str allowed_tools = allowed_tools_raw.split() # Validate metadata is a dict if present - if custom_metadata is not None and not isinstance(custom_metadata, dict): - custom_metadata = None + typed_metadata: dict[str, str] | None = None + if isinstance(custom_metadata, dict): + typed_metadata = {str(k): str(v) for k, v in custom_metadata.items()} return SkillManifest( name=name.strip(), @@ -200,7 +201,7 @@ def _parse_manifest(cls, manifest_path: Path) -> tuple[SkillManifest | None, str path=manifest_path, license=license_field.strip() if isinstance(license_field, str) else None, compatibility=compatibility.strip() if isinstance(compatibility, str) else None, - metadata=custom_metadata, + metadata=typed_metadata, allowed_tools=allowed_tools, ), None From 63c9d1aedfe9447f3b8c44c58f388417f8a9d4e5 Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Tue, 23 Dec 2025 23:37:25 +0000 Subject: [PATCH 12/15] add ty to ci --- .github/workflows/checks.yml | 3 +++ .pre-commit-config.yaml | 7 +++++++ pyproject.toml | 1 + 3 files changed, 11 insertions(+) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 0a3460491..d0e1358f0 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -48,6 +48,9 @@ jobs: - name: Run pyright run: uv run scripts/lint.py + - name: Run ty + run: uv run scripts/typecheck.py + test: runs-on: ubuntu-latest steps: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e20ec38f4..f71dce348 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,3 +8,10 @@ repos: args: [--fix] # Run the formatter. 
- id: ruff-format + - repo: local + hooks: + - id: ty + name: ty check + entry: uv run scripts/typecheck.py + language: system + pass_filenames: false diff --git a/pyproject.toml b/pyproject.toml index 05878383f..e63cfb46d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,6 +113,7 @@ dev = [ "ruamel.yaml>=0.18.0", "pyyaml>=6.0.2", "ruff>=0.8.4", + "ty>=0.0.5", "pytest>=7.4.0", "pytest-asyncio>=0.21.1", "pytest-cov>=6.1.1", From eb3d5a98af267ee1309b32f1cfd352cbea47e6ec Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Tue, 23 Dec 2025 23:37:58 +0000 Subject: [PATCH 13/15] typecheck --- scripts/typecheck.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 scripts/typecheck.py diff --git a/scripts/typecheck.py b/scripts/typecheck.py new file mode 100644 index 000000000..449b092b9 --- /dev/null +++ b/scripts/typecheck.py @@ -0,0 +1,35 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "typer", +# "ty", +# ] +# /// + +import subprocess +import sys + +import typer +from rich import print + + +def main(path: str = "src") -> None: + try: + command = ["ty", "check", path] + process = subprocess.run( + command, + check=True, + stdout=sys.stdout, + stderr=sys.stderr, + ) + sys.exit(process.returncode) + except subprocess.CalledProcessError as e: + print(f"Error: {e}") + sys.exit(e.returncode) + except FileNotFoundError: + print("Error: `ty` command not found. Make sure it's installed in the environment.") + sys.exit(1) + + +if __name__ == "__main__": + typer.run(main) From d28ba32dcaa4d5323a79e2d4d61bc94d365d46e0 Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Tue, 23 Dec 2025 23:40:20 +0000 Subject: [PATCH 14/15] add ty, typecheck, diff for toad --- .../src/hf_inference_acp/agents.py | 5 +- .../src/hf_inference_acp/cli.py | 3 +- .../src/hf_inference_acp/hf_config.py | 2 +- .../src/hf_inference_acp/wizard/wizard_llm.py | 2 + src/fast_agent/acp/filesystem_runtime.py | 36 +++++++++++++ src/fast_agent/acp/tool_permission_adapter.py | 10 ++++ src/fast_agent/acp/tool_permissions.py | 28 ++++++++++ src/fast_agent/cli/commands/url_parser.py | 11 ++-- src/fast_agent/core/agent_app.py | 11 ++-- src/fast_agent/core/logging/listeners.py | 2 + src/fast_agent/mcp/hf_auth.py | 52 ++++++++++++++----- src/fast_agent/ui/console_display.py | 5 -- src/fast_agent/ui/rich_progress.py | 1 - src/fast_agent/ui/streaming.py | 6 --- .../e2e/history/test_history_save_load_e2e.py | 2 +- tests/e2e/multimodal/image_server.py | 7 +-- tests/e2e/multimodal/video_server.py | 2 +- tests/e2e/smoke/tensorzero/test_image_demo.py | 1 + .../acp/test_acp_content_blocks.py | 12 ++--- .../acp/test_acp_runtime_telemetry.py | 11 ++-- .../acp/test_acp_tool_notifications.py | 11 ++-- .../commands/test_url_parser_hf_auth.py | 9 +++- tests/unit/fast_agent/mcp/test_hf_auth.py | 13 +++-- 23 files changed, 181 insertions(+), 61 deletions(-) diff --git a/publish/hf-inference-acp/src/hf_inference_acp/agents.py b/publish/hf-inference-acp/src/hf_inference_acp/agents.py index f076cd5d7..e1febad01 100644 --- a/publish/hf-inference-acp/src/hf_inference_acp/agents.py +++ b/publish/hf-inference-acp/src/hf_inference_acp/agents.py @@ -131,8 +131,9 @@ async def attach_llm(self, llm_factory, model=None, request_params=None, **kwarg llm = await super().attach_llm(llm_factory, model, request_params, **kwargs) # Set up wizard callback if LLM supports it - if hasattr(llm, "set_completion_callback"): - 
llm.set_completion_callback(self._on_wizard_complete) + callback_setter = getattr(llm, "set_completion_callback", None) + if callback_setter is not None: + callback_setter(self._on_wizard_complete) return llm diff --git a/publish/hf-inference-acp/src/hf_inference_acp/cli.py b/publish/hf-inference-acp/src/hf_inference_acp/cli.py index 50b3ed46f..d5dba4f84 100644 --- a/publish/hf-inference-acp/src/hf_inference_acp/cli.py +++ b/publish/hf-inference-acp/src/hf_inference_acp/cli.py @@ -11,6 +11,7 @@ import shlex import sys from pathlib import Path # noqa: TC003 - typer needs runtime access +from typing import Any, cast import typer @@ -220,7 +221,7 @@ async def run_agents( if skills_directory is not None: fast_kwargs["skills_directory"] = skills_directory - fast = FastAgent(**fast_kwargs) + fast = FastAgent(**cast("Any", fast_kwargs)) if shell_runtime: await fast.app.initialize() diff --git a/publish/hf-inference-acp/src/hf_inference_acp/hf_config.py b/publish/hf-inference-acp/src/hf_inference_acp/hf_config.py index 00115e5e7..9d41e6a54 100644 --- a/publish/hf-inference-acp/src/hf_inference_acp/hf_config.py +++ b/publish/hf-inference-acp/src/hf_inference_acp/hf_config.py @@ -54,7 +54,7 @@ def discover_hf_token(*, ignore_env: bool = False) -> tuple[str | None, str | No from huggingface_hub import get_token token = get_token() - return token, "huggingface_hub" if token else (None, None) + return (token, "huggingface_hub") if token else (None, None) except ImportError: pass diff --git a/publish/hf-inference-acp/src/hf_inference_acp/wizard/wizard_llm.py b/publish/hf-inference-acp/src/hf_inference_acp/wizard/wizard_llm.py index c97af7900..55b6eebab 100644 --- a/publish/hf-inference-acp/src/hf_inference_acp/wizard/wizard_llm.py +++ b/publish/hf-inference-acp/src/hf_inference_acp/wizard/wizard_llm.py @@ -414,6 +414,8 @@ async def _handle_confirm(self, user_input: str) -> str: elif cmd in ("y", "yes", "confirm", "ok", "save"): # Save configuration try: + if self._state.selected_model is None: + return "No model selected. Please select a model first." 
update_model_in_config(self._state.selected_model) update_mcp_server_load_on_start("huggingface", self._state.mcp_load_on_start) self._state.stage = WizardStage.COMPLETE diff --git a/src/fast_agent/acp/filesystem_runtime.py b/src/fast_agent/acp/filesystem_runtime.py index 21761749c..f3f6e5977 100644 --- a/src/fast_agent/acp/filesystem_runtime.py +++ b/src/fast_agent/acp/filesystem_runtime.py @@ -8,6 +8,8 @@ from typing import TYPE_CHECKING, Any +from acp.helpers import tool_diff_content +from acp.schema import ToolCallProgress from mcp.types import CallToolResult, Tool from fast_agent.core.logging.logger import get_logger @@ -312,6 +314,18 @@ async def write_text_file( content_length=len(content), ) + # Read existing file content for diff display (if file exists) + old_text: str | None = None + try: + response = await self.connection.read_text_file( + path=path, + session_id=self.session_id, + ) + old_text = response.content + except Exception: + # File doesn't exist or can't be read - that's fine, old_text stays None + pass + # Check permission before execution if self._permission_handler: try: @@ -320,6 +334,9 @@ async def write_text_file( server_name="acp_filesystem", arguments=arguments, tool_use_id=tool_use_id, + diff_old_text=old_text, + diff_new_text=content, + diff_path=path, ) if not permission_result.allowed: error_msg = permission_result.error_message or ( @@ -382,6 +399,25 @@ async def write_text_file( except Exception as e: self.logger.error(f"Error in tool complete handler: {e}", exc_info=True) + # Send diff content update for UI display + if tool_call_id: + try: + diff_content = tool_diff_content( + path=path, + new_text=content, + old_text=old_text, + ) + await self.connection.session_update( + session_id=self.session_id, + update=ToolCallProgress( + session_update="tool_call_update", + tool_call_id=tool_call_id, + content=[diff_content], + ), + ) + except Exception as e: + self.logger.error(f"Error sending diff content update: {e}", exc_info=True) + return result except Exception as e: diff --git a/src/fast_agent/acp/tool_permission_adapter.py b/src/fast_agent/acp/tool_permission_adapter.py index 17b9a358d..cd3d38660 100644 --- a/src/fast_agent/acp/tool_permission_adapter.py +++ b/src/fast_agent/acp/tool_permission_adapter.py @@ -64,6 +64,10 @@ async def check_permission( server_name: str, arguments: dict[str, Any] | None = None, tool_use_id: str | None = None, + *, + diff_old_text: str | None = None, + diff_new_text: str | None = None, + diff_path: str | None = None, ) -> ToolPermissionResult: """ Check if tool execution is permitted. 
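The bare "*" added to these permission signatures is doing real work: every diff_* parameter is keyword-only, so existing positional call sites keep working unchanged and callers opt in to the diff plumbing explicitly. A small self-contained sketch of the pattern (the check function below is illustrative, not code from this patch):

    def check(tool_name: str, server_name: str, *, diff_path: str | None = None) -> str:
        # Positional arguments end at the "*"; diff_path must be passed by name.
        return f"{server_name}/{tool_name} diff={diff_path}"

    check("write_text_file", "acp_filesystem")                    # existing call style, unaffected
    check("write_text_file", "acp_filesystem", diff_path="a.py")  # new, explicit opt-in
    # check("write_text_file", "acp_filesystem", "a.py")          # TypeError: not allowed positionally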
@@ -76,6 +80,9 @@ async def check_permission( server_name: Name of the MCP server providing the tool arguments: Tool arguments tool_use_id: LLM's tool use ID + diff_old_text: Original file content for diff display (optional) + diff_new_text: New file content for diff display (optional) + diff_path: File path for diff display (optional) Returns: ToolPermissionResult indicating whether execution is allowed @@ -93,6 +100,9 @@ async def check_permission( server_name=server_name, arguments=arguments, tool_call_id=tool_call_id, + diff_old_text=diff_old_text, + diff_new_text=diff_new_text, + diff_path=diff_path, ) namespaced_tool_name = create_namespaced_name(server_name, tool_name) diff --git a/src/fast_agent/acp/tool_permissions.py b/src/fast_agent/acp/tool_permissions.py index c7957a34e..5d44c97e9 100644 --- a/src/fast_agent/acp/tool_permissions.py +++ b/src/fast_agent/acp/tool_permissions.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Protocol, runtime_checkable from acp.schema import ( + FileEditToolCallContent, PermissionOption, ToolCallProgress, ToolCallUpdate, @@ -161,6 +162,10 @@ async def check_permission( server_name: str, arguments: dict[str, Any] | None = None, tool_call_id: str | None = None, + *, + diff_old_text: str | None = None, + diff_new_text: str | None = None, + diff_path: str | None = None, ) -> PermissionResult: """ Check if tool execution is permitted. @@ -175,6 +180,9 @@ async def check_permission( server_name: Name of the MCP server providing the tool arguments: Tool arguments tool_call_id: Optional tool call ID for tracking + diff_old_text: Original file content for diff display (optional) + diff_new_text: New file content for diff display (optional) + diff_path: File path for diff display (optional) Returns: PermissionResult indicating whether execution is allowed @@ -212,6 +220,9 @@ async def check_permission( arguments=arguments, tool_call_id=tool_call_id, permission_key=permission_key, + diff_old_text=diff_old_text, + diff_new_text=diff_new_text, + diff_path=diff_path, ) except Exception as e: @@ -230,6 +241,9 @@ async def _request_permission_from_client( arguments: dict[str, Any] | None, tool_call_id: str | None, permission_key: str, + diff_old_text: str | None = None, + diff_new_text: str | None = None, + diff_path: str | None = None, ) -> PermissionResult: """ Request permission from the ACP client. 
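Taken together, these hunks thread file-diff context from the write path down to the client's permission prompt. A hedged sketch of how a caller might drive that seam; "handler" is assumed to be an already-constructed permission handler of the kind shown here, and the file bodies are placeholders:

    async def gated_write(handler, path: str, old_text: str | None, new_text: str) -> bool:
        # Forward both file bodies so the ACP client can render a diff
        # alongside the permission request (keyword-only by design).
        result = await handler.check_permission(
            tool_name="write_text_file",
            server_name="acp_filesystem",
            arguments={"path": path},
            diff_old_text=old_text,
            diff_new_text=new_text,
            diff_path=path,
        )
        return result.allowed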
@@ -273,12 +287,26 @@ async def _request_permission_from_client( # Create ToolCallUpdate object per ACP spec with raw_input for full argument visibility tool_kind = _infer_tool_kind(tool_name, arguments) + + # Build diff content if provided (for file edit operations) + content = None + if diff_new_text is not None and diff_path is not None: + content = [ + FileEditToolCallContent( + type="diff", + new_text=diff_new_text, + old_text=diff_old_text, + path=diff_path, + ) + ] + tool_call = ToolCallUpdate( tool_call_id=tool_call_id or "pending", title=title, kind=tool_kind, status="pending", raw_input=arguments, # Include full arguments so client can display them + content=content, ) # Create permission request with options diff --git a/src/fast_agent/cli/commands/url_parser.py b/src/fast_agent/cli/commands/url_parser.py index 337da45b9..c48c13e80 100644 --- a/src/fast_agent/cli/commands/url_parser.py +++ b/src/fast_agent/cli/commands/url_parser.py @@ -8,7 +8,7 @@ from typing import Literal from urllib.parse import urlparse -from fast_agent.mcp.hf_auth import add_hf_auth_header +from fast_agent.mcp.hf_auth import TokenProvider, add_hf_auth_header def parse_server_url( @@ -102,7 +102,9 @@ def generate_server_name(url: str) -> str: def parse_server_urls( - urls_param: str, auth_token: str | None = None + urls_param: str, + auth_token: str | None = None, + hub_token_provider: TokenProvider | None = None, ) -> list[tuple[str, Literal["http", "sse"], str, dict[str, str] | None]]: """ Parse a comma-separated list of URLs into server configurations. @@ -110,6 +112,9 @@ def parse_server_urls( Args: urls_param: Comma-separated list of URLs auth_token: Optional bearer token for authorization + hub_token_provider: Optional callable that returns a HuggingFace token. + Defaults to using huggingface_hub.get_token(). Pass a custom provider + for testing. 
Returns: List of tuples containing (server_name, transport_type, url, headers) @@ -134,7 +139,7 @@ def parse_server_urls( server_name, transport_type, parsed_url = parse_server_url(url) # Apply HuggingFace authentication if appropriate - final_headers = add_hf_auth_header(parsed_url, headers) + final_headers = add_hf_auth_header(parsed_url, headers, hub_token_provider) result.append((server_name, transport_type, parsed_url, final_headers)) diff --git a/src/fast_agent/core/agent_app.py b/src/fast_agent/core/agent_app.py index 52d2611fe..6f5ecdedf 100644 --- a/src/fast_agent/core/agent_app.py +++ b/src/fast_agent/core/agent_app.py @@ -9,6 +9,7 @@ from rich import print as rich_print from fast_agent.agents.agent_types import AgentType +from fast_agent.agents.workflow.parallel_agent import ParallelAgent from fast_agent.core.exceptions import AgentConfigError, ServerConfigError from fast_agent.interfaces import AgentProtocol from fast_agent.llm.usage_tracking import last_turn_usage @@ -369,12 +370,10 @@ def record(target: AgentProtocol) -> None: if accumulator is not None: indices[target.name] = len(accumulator.turns) - if agent.agent_type == AgentType.PARALLEL: - if getattr(agent, "fan_out_agents", None): - for child_agent in agent.fan_out_agents: - record(child_agent) - if getattr(agent, "fan_in_agent", None): - record(agent.fan_in_agent) + if isinstance(agent, ParallelAgent): + for child_agent in agent.fan_out_agents: + record(child_agent) + record(agent.fan_in_agent) else: record(agent) diff --git a/src/fast_agent/core/logging/listeners.py b/src/fast_agent/core/logging/listeners.py index ffb8e6544..45ad7ee14 100644 --- a/src/fast_agent/core/logging/listeners.py +++ b/src/fast_agent/core/logging/listeners.py @@ -280,6 +280,8 @@ async def stop(self) -> None: await self.flush() async def _periodic_flush(self) -> None: + if self._stop_event is None: + return try: while not self._stop_event.is_set(): try: diff --git a/src/fast_agent/mcp/hf_auth.py b/src/fast_agent/mcp/hf_auth.py index 6de95f860..5347e5a8a 100644 --- a/src/fast_agent/mcp/hf_auth.py +++ b/src/fast_agent/mcp/hf_auth.py @@ -1,8 +1,22 @@ """HuggingFace authentication utilities for MCP connections.""" import os +from collections.abc import Callable from urllib.parse import urlparse +# Type alias for token provider functions +TokenProvider = Callable[[], str | None] + + +def _default_hub_token_provider() -> str | None: + """Default token provider that uses huggingface_hub.get_token().""" + try: + from huggingface_hub import get_token # type: ignore + + return get_token() + except Exception: + return None + def is_huggingface_url(url: str) -> bool: """ @@ -47,13 +61,19 @@ def is_huggingface_url(url: str) -> bool: return False -def get_hf_token_from_env() -> str | None: +def get_hf_token_from_env( + hub_token_provider: TokenProvider | None = None, +) -> str | None: """ Get the HuggingFace token from the HF_TOKEN environment variable. Falls back to `huggingface_hub.get_token()` when available, so users who have authenticated via `hf auth login` don't need to manually export HF_TOKEN. + Args: + hub_token_provider: Optional callable that returns a token. Defaults to + using huggingface_hub.get_token(). Pass a custom provider for testing. 
+ Returns: The HF_TOKEN value if set, None otherwise """ @@ -61,21 +81,23 @@ def get_hf_token_from_env() -> str | None: if token: return token - try: - from huggingface_hub import get_token # type: ignore - - return get_token() - except Exception: - return None + provider = hub_token_provider if hub_token_provider is not None else _default_hub_token_provider + return provider() -def should_add_hf_auth(url: str, existing_headers: dict[str, str] | None) -> bool: +def should_add_hf_auth( + url: str, + existing_headers: dict[str, str] | None, + hub_token_provider: TokenProvider | None = None, +) -> bool: """ Determine if HuggingFace authentication should be added to the headers. Args: url: The URL to check existing_headers: Existing headers dictionary (may be None) + hub_token_provider: Optional callable that returns a token. Defaults to + using huggingface_hub.get_token(). Pass a custom provider for testing. Returns: True if HF auth should be added, False otherwise @@ -106,24 +128,30 @@ def should_add_hf_auth(url: str, existing_headers: dict[str, str] | None) -> boo if "Authorization" in existing_headers: return False - return get_hf_token_from_env() is not None + return get_hf_token_from_env(hub_token_provider) is not None -def add_hf_auth_header(url: str, headers: dict[str, str] | None) -> dict[str, str] | None: +def add_hf_auth_header( + url: str, + headers: dict[str, str] | None, + hub_token_provider: TokenProvider | None = None, +) -> dict[str, str] | None: """ Add HuggingFace authentication header if appropriate. Args: url: The URL to check headers: Existing headers dictionary (may be None) + hub_token_provider: Optional callable that returns a token. Defaults to + using huggingface_hub.get_token(). Pass a custom provider for testing. Returns: Updated headers dictionary with HF auth if appropriate, or original headers """ - if not should_add_hf_auth(url, headers): + if not should_add_hf_auth(url, headers, hub_token_provider): return headers - hf_token = get_hf_token_from_env() + hf_token = get_hf_token_from_env(hub_token_provider) if hf_token is None: return headers diff --git a/src/fast_agent/ui/console_display.py b/src/fast_agent/ui/console_display.py index 77c70a697..c1d48fa06 100644 --- a/src/fast_agent/ui/console_display.py +++ b/src/fast_agent/ui/console_display.py @@ -778,15 +778,10 @@ def streaming_assistant_message( header_right=right_info, progress_display=progress_display, ) - # Pause progress display BEFORE yielding to prevent race condition with Anthropic - # (Anthropic's stream context manager may start events during __aenter__) - progress_display.pause() try: yield handle finally: handle.close() - # Resume progress display - must be explicit since we paused externally - progress_display.resume() def _display_mermaid_diagrams(self, diagrams: list[MermaidDiagram]) -> None: """Display mermaid diagram links.""" diff --git a/src/fast_agent/ui/rich_progress.py b/src/fast_agent/ui/rich_progress.py index a6e1ccbab..a176912d9 100644 --- a/src/fast_agent/ui/rich_progress.py +++ b/src/fast_agent/ui/rich_progress.py @@ -49,7 +49,6 @@ def pause(self) -> None: """Pause the progress display.""" if not self._paused: self._paused = True - for task in self._progress.tasks: task.visible = False self._progress.stop() diff --git a/src/fast_agent/ui/streaming.py b/src/fast_agent/ui/streaming.py index 430639398..81bdce68f 100644 --- a/src/fast_agent/ui/streaming.py +++ b/src/fast_agent/ui/streaming.py @@ -820,19 +820,13 @@ def handle_tool_event(self, event_type: str, info: dict[str, Any] | None 
= None) if event_type == "start": self._begin_tool_mode() - if not streams_arguments: - self._pause_progress_display() self.update(f"→ Calling {tool_name}\n") return if event_type == "delta": if streams_arguments and info and "chunk" in info: self.update(info["chunk"]) - elif event_type == "text": - self._pause_progress_display() elif event_type == "stop": self._end_tool_mode() - if not streams_arguments: - self._resume_progress_display() except Exception as exc: logger.warning( "Error handling tool event", diff --git a/tests/e2e/history/test_history_save_load_e2e.py b/tests/e2e/history/test_history_save_load_e2e.py index e017cd27a..ace542507 100644 --- a/tests/e2e/history/test_history_save_load_e2e.py +++ b/tests/e2e/history/test_history_save_load_e2e.py @@ -138,7 +138,7 @@ async def _get_or_create_history_file(create_model: str, tmp_path_factory) -> Pa async with agent_session(create_model, f"history-create-{create_model}") as creator_agent: await _create_history(creator_agent) - save_messages(creator_agent.message_history, history_file) + save_messages(creator_agent.message_history, str(history_file)) assert history_file.exists() _HISTORY_CACHE[create_model] = history_file diff --git a/tests/e2e/multimodal/image_server.py b/tests/e2e/multimodal/image_server.py index 7c883a6e6..36009666e 100644 --- a/tests/e2e/multimodal/image_server.py +++ b/tests/e2e/multimodal/image_server.py @@ -7,9 +7,10 @@ import logging import sys from pathlib import Path +from typing import cast from mcp.server.fastmcp import Context, FastMCP, Image -from mcp.types import BlobResourceContents, EmbeddedResource, ImageContent, TextContent +from mcp.types import AnyUrl, BlobResourceContents, EmbeddedResource, ImageContent, TextContent # Configure logging logging.basicConfig(level=logging.INFO) @@ -28,7 +29,7 @@ structured_output=False, ) async def get_image( - image_name: str = "default", ctx: Context = None + image_name: str = "default", ctx: Context | None = None ) -> list[TextContent | ImageContent]: try: # Use the global image path @@ -66,7 +67,7 @@ async def get_pdf() -> list[TextContent | EmbeddedResource]: EmbeddedResource( type="resource", resource=BlobResourceContents( - uri=f"file://{Path(pdf_path).absolute()}", + uri=cast("AnyUrl", f"file://{Path(pdf_path).absolute()}"), blob=b64_data, mimeType="application/pdf", ), diff --git a/tests/e2e/multimodal/video_server.py b/tests/e2e/multimodal/video_server.py index c0804e47a..2269bb872 100644 --- a/tests/e2e/multimodal/video_server.py +++ b/tests/e2e/multimodal/video_server.py @@ -30,7 +30,7 @@ async def get_video_link() -> list[TextContent | ResourceLink]: """Return a ResourceLink to a video.""" return [ - text_content(type="text", text="Here's a video link for analysis:"), + text_content("Here's a video link for analysis:"), video_link(video_url, name="Mystery Video"), ] diff --git a/tests/e2e/smoke/tensorzero/test_image_demo.py b/tests/e2e/smoke/tensorzero/test_image_demo.py index 1b7d81975..7ecc27972 100644 --- a/tests/e2e/smoke/tensorzero/test_image_demo.py +++ b/tests/e2e/smoke/tensorzero/test_image_demo.py @@ -39,6 +39,7 @@ async def test_tensorzero_image_demo_smoke(project_root, chdir_to_tensorzero_exa pytest.fail(f"'main' async function not found in {image_demo_script_path}") print("Executing image_demo.main()...") + assert main_func is not None # Type narrowing after pytest.fail await main_func() print("image_demo.main() executed successfully.") diff --git a/tests/integration/acp/test_acp_content_blocks.py 
b/tests/integration/acp/test_acp_content_blocks.py index 0aca1286d..0f2207ebe 100644 --- a/tests/integration/acp/test_acp_content_blocks.py +++ b/tests/integration/acp/test_acp_content_blocks.py @@ -95,7 +95,7 @@ async def test_acp_image_content_processing() -> None: ImageContentBlock( type="image", data=fake_image_data, - mimeType="image/png", + mime_type="image/png", ), ] @@ -160,7 +160,7 @@ async def test_acp_embedded_text_resource_processing() -> None: type="resource", resource=TextResourceContents( uri="file:///example.py", - mimeType="text/x-python", + mime_type="text/x-python", text="def hello():\n return 'Hello, world!'", ), ), @@ -219,7 +219,7 @@ async def test_acp_embedded_blob_resource_processing() -> None: type="resource", resource=BlobResourceContents( uri="file:///document.pdf", - mimeType="application/pdf", + mime_type="application/pdf", blob=fake_blob_data, ), ), @@ -271,7 +271,7 @@ async def test_acp_mixed_content_blocks() -> None: type="resource", resource=TextResourceContents( uri="file:///app.py", - mimeType="text/x-python", + mime_type="text/x-python", text="import sys\nprint(sys.version)", ), ), @@ -279,7 +279,7 @@ async def test_acp_mixed_content_blocks() -> None: ImageContentBlock( type="image", data=image_data, - mimeType="image/png", + mime_type="image/png", ), text_block("What's wrong?"), ] @@ -336,7 +336,7 @@ async def test_acp_resource_only_prompt_not_slash_command() -> None: type="resource", resource=TextResourceContents( uri="file:///C:/Users/shaun/AppData/Roaming/Zed/settings.json", - mimeType="application/json", + mime_type="application/json", text="//hello, world!", ), ), diff --git a/tests/integration/acp/test_acp_runtime_telemetry.py b/tests/integration/acp/test_acp_runtime_telemetry.py index c82f96508..da9f04338 100644 --- a/tests/integration/acp/test_acp_runtime_telemetry.py +++ b/tests/integration/acp/test_acp_runtime_telemetry.py @@ -10,6 +10,7 @@ import asyncio import sys from pathlib import Path +from typing import Any import pytest from acp.helpers import text_block @@ -34,11 +35,13 @@ def _get_stop_reason(response: object) -> str | None: return getattr(response, "stop_reason", None) or getattr(response, "stopReason", None) -def _get_session_update_type(update: object) -> str | None: - if hasattr(update, "sessionUpdate"): - return update.sessionUpdate +def _get_session_update_type(update: Any) -> str | None: + session_update = getattr(update, "sessionUpdate", None) + if session_update is not None: + return str(session_update) if isinstance(update, dict): - return update.get("sessionUpdate") + result = update.get("sessionUpdate") + return str(result) if result is not None else None return None diff --git a/tests/integration/acp/test_acp_tool_notifications.py b/tests/integration/acp/test_acp_tool_notifications.py index a2b3823b0..79dad9d11 100644 --- a/tests/integration/acp/test_acp_tool_notifications.py +++ b/tests/integration/acp/test_acp_tool_notifications.py @@ -10,6 +10,7 @@ import asyncio import sys from pathlib import Path +from typing import Any import pytest from acp.helpers import text_block @@ -36,11 +37,13 @@ def _get_stop_reason(response: object) -> str | None: return getattr(response, "stop_reason", None) or getattr(response, "stopReason", None) -def _get_session_update_type(update: object) -> str | None: - if hasattr(update, "sessionUpdate"): - return update.sessionUpdate +def _get_session_update_type(update: Any) -> str | None: + session_update = getattr(update, "sessionUpdate", None) + if session_update is not None: + return 
str(session_update) if isinstance(update, dict): - return update.get("sessionUpdate") + result = update.get("sessionUpdate") + return str(result) if result is not None else None return None FAST_AGENT_CMD = ( sys.executable, diff --git a/tests/unit/fast_agent/commands/test_url_parser_hf_auth.py b/tests/unit/fast_agent/commands/test_url_parser_hf_auth.py index ed26350cb..4da088284 100644 --- a/tests/unit/fast_agent/commands/test_url_parser_hf_auth.py +++ b/tests/unit/fast_agent/commands/test_url_parser_hf_auth.py @@ -9,6 +9,11 @@ from fast_agent.cli.commands.url_parser import parse_server_urls +def _no_hub_token() -> None: + """Token provider that always returns None (no huggingface_hub token).""" + return None + + def _set_hf_token(value: str | None) -> str | None: """Set HF_TOKEN environment variable and return the original value.""" original = os.getenv("HF_TOKEN") @@ -53,7 +58,9 @@ def test_hf_url_without_token_no_auth_header(self): """Test that HF URLs don't get auth headers when no token is available.""" original = _set_hf_token(None) try: - result = parse_server_urls("https://hf.co/models/gpt2") + result = parse_server_urls( + "https://hf.co/models/gpt2", hub_token_provider=_no_hub_token + ) assert len(result) == 1 server_name, transport_type, url, headers = result[0] diff --git a/tests/unit/fast_agent/mcp/test_hf_auth.py b/tests/unit/fast_agent/mcp/test_hf_auth.py index c2bf12051..6a4de05ba 100644 --- a/tests/unit/fast_agent/mcp/test_hf_auth.py +++ b/tests/unit/fast_agent/mcp/test_hf_auth.py @@ -14,6 +14,11 @@ ) +def _no_hub_token() -> None: + """Token provider that always returns None (no huggingface_hub token).""" + return None + + def _set_hf_token(value: str | None) -> str | None: """Set HF_TOKEN environment variable and return the original value.""" original = os.getenv("HF_TOKEN") @@ -105,14 +110,14 @@ def test_token_present(self): def test_token_absent(self): original = _set_hf_token(None) try: - assert get_hf_token_from_env() is None + assert get_hf_token_from_env(hub_token_provider=_no_hub_token) is None finally: _restore_hf_token(original) def test_token_empty_string(self): original = _set_hf_token("") try: - assert None is get_hf_token_from_env() + assert get_hf_token_from_env(hub_token_provider=_no_hub_token) is None finally: _restore_hf_token(original) @@ -130,7 +135,7 @@ def test_hf_url_no_existing_auth_with_token(self): def test_hf_url_no_existing_auth_no_token(self): original = _set_hf_token(None) try: - assert should_add_hf_auth("https://hf.co/models", None) is False + assert should_add_hf_auth("https://hf.co/models", None, _no_hub_token) is False finally: _restore_hf_token(original) @@ -277,7 +282,7 @@ def test_returns_none_when_no_headers_and_no_auth_needed(self): def test_returns_none_when_no_token_available(self): original = _set_hf_token(None) try: - result = add_hf_auth_header("https://hf.co/models", None) + result = add_hf_auth_header("https://hf.co/models", None, _no_hub_token) assert result is None finally: _restore_hf_token(original) From 7e9c258aab307a8b9bb21a61b2ff9dabcbe901c2 Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Tue, 23 Dec 2025 23:50:56 +0000 Subject: [PATCH 15/15] fix typecheck script --- scripts/typecheck.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/scripts/typecheck.py b/scripts/typecheck.py index 449b092b9..5dac38bfe 100644 --- a/scripts/typecheck.py +++ b/scripts/typecheck.py @@ -1,11 +1,3 @@ -# /// script -# requires-python = ">=3.10" -# dependencies = [ -# "typer", -# "ty", -# 
] -# /// - import subprocess import sys
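A closing note on the pattern that recurs through this series: replacing hidden module-level lookups with an injectable callable. The TokenProvider seam in hf_auth.py is the clearest instance, which is why the updated tests pass a fake provider instead of monkeypatching huggingface_hub. A minimal sketch under stated assumptions (the token value is a placeholder, and the Authorization header key is inferred from the checks in should_add_hf_auth):

    from fast_agent.mcp.hf_auth import add_hf_auth_header

    def fake_hub_token() -> str | None:
        # Stands in for huggingface_hub.get_token(): always "logged in".
        return "hf_dummy_token"

    # hf.co URLs with no explicit Authorization header pick up the token
    # from the injected provider (an HF_TOKEN env var would win first).
    headers = add_hf_auth_header("https://hf.co/mcp", None, fake_hub_token)
    assert headers is not None and "Authorization" in headers

    # Non-HuggingFace hosts are left untouched.
    assert add_hf_auth_header("https://example.com/mcp", None, fake_hub_token) is None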