From 78f2f825cb5262f18722ae9c13d9594409521fb0 Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Sat, 6 Dec 2025 01:01:21 -0500 Subject: [PATCH 01/23] feat: add A2A service and utilities for tasks and push --- src/a2a_handler/a2a_service.py | 498 +++++++++++++++++++++++++++ src/a2a_handler/cli.py | 608 ++++++++++++++++++++++++++++----- src/a2a_handler/client.py | 114 +++---- src/a2a_handler/push_server.py | 166 +++++++++ src/a2a_handler/session.py | 165 +++++++++ 5 files changed, 1388 insertions(+), 163 deletions(-) create mode 100644 src/a2a_handler/a2a_service.py create mode 100644 src/a2a_handler/push_server.py create mode 100644 src/a2a_handler/session.py diff --git a/src/a2a_handler/a2a_service.py b/src/a2a_handler/a2a_service.py new file mode 100644 index 0000000..3bae18c --- /dev/null +++ b/src/a2a_handler/a2a_service.py @@ -0,0 +1,498 @@ +"""A2A protocol service layer. + +Provides a unified interface for A2A operations, reusable between CLI and TUI. +""" + +import uuid +from dataclasses import dataclass, field +from typing import Any, AsyncIterator + +import httpx +from a2a.client import A2ACardResolver, Client, ClientConfig, ClientFactory +from a2a.types import ( + AgentCard, + GetTaskPushNotificationConfigParams, + Message, + Part, + Role, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskState, + TaskStatusUpdateEvent, + TextPart, + TransportProtocol, +) + +from a2a_handler.common import get_logger + +log = get_logger(__name__) + +TERMINAL_STATES = { + TaskState.completed, + TaskState.canceled, + TaskState.failed, + TaskState.rejected, +} + + +@dataclass +class SendResult: + """Result from sending a message to an agent.""" + + task: Task | None = None + message: Message | None = None + context_id: str | None = None + task_id: str | None = None + state: TaskState | None = None + text: str = "" + raw: dict[str, Any] = field(default_factory=dict) + + @property + def is_complete(self) -> bool: + """Check if the task reached a terminal state.""" + return self.state in TERMINAL_STATES if self.state else False + + @property + def needs_input(self) -> bool: + """Check if the task is waiting for user input.""" + return self.state == TaskState.input_required if self.state else False + + +@dataclass +class StreamEvent: + """A single event from a streaming response.""" + + event_type: str # "task", "message", "status", "artifact" + task: Task | None = None + message: Message | None = None + status: TaskStatusUpdateEvent | None = None + artifact: TaskArtifactUpdateEvent | None = None + context_id: str | None = None + task_id: str | None = None + state: TaskState | None = None + text: str = "" + + +@dataclass +class TaskResult: + """Result from a task operation (get/cancel).""" + + task: Task + task_id: str + state: TaskState + context_id: str | None = None + text: str = "" + raw: dict[str, Any] = field(default_factory=dict) + + +def _extract_text_from_parts(parts: list[Part] | None) -> str: + """Extract text content from message parts.""" + if not parts: + return "" + texts = [] + for part in parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + texts.append(part.root.text) + elif hasattr(part, "text"): + texts.append(part.text) + return "\n".join(t for t in texts if t) + + +def _extract_text_from_task(task: Task) -> str: + """Extract text from task artifacts and history.""" + texts = [] + if task.artifacts: + for artifact in task.artifacts: + if artifact.parts: + texts.append(_extract_text_from_parts(artifact.parts)) + if task.history: + for msg in task.history: + if msg.role == Role.agent and msg.parts: + texts.append(_extract_text_from_parts(msg.parts)) + return "\n".join(t for t in texts if t) + + +class A2AService: + """High-level service for A2A protocol operations. + + Wraps the a2a-sdk Client and provides a simplified interface + for common operations. Designed to be shared between CLI and TUI. + """ + + def __init__( + self, + http_client: httpx.AsyncClient, + agent_url: str, + streaming: bool = True, + ): + """Initialize the A2A service. + + Args: + http_client: Async HTTP client to use for requests + agent_url: Base URL of the A2A agent + streaming: Whether to prefer streaming when available + """ + self.http_client = http_client + self.agent_url = agent_url + self.streaming = streaming + self._client: Client | None = None + self._card: AgentCard | None = None + + async def get_card(self) -> AgentCard: + """Fetch and cache the agent card. + + Returns: + The agent's card with metadata and capabilities + """ + if self._card is None: + log.info("Fetching agent card from %s", self.agent_url) + resolver = A2ACardResolver(self.http_client, self.agent_url) + self._card = await resolver.get_agent_card() + log.info("Connected to agent: %s", self._card.name) + return self._card + + async def _get_client(self) -> Client: + """Get or create the A2A client. + + Returns: + Configured A2A client instance + """ + if self._client is None: + card = await self.get_card() + config = ClientConfig( + httpx_client=self.http_client, + supported_transports=[TransportProtocol.jsonrpc], + streaming=self.streaming, + ) + factory = ClientFactory(config) + self._client = factory.create(card) + log.debug("Created A2A client for %s", card.name) + return self._client + + @property + def supports_streaming(self) -> bool: + """Check if the agent supports streaming.""" + if self._card and self._card.capabilities: + return bool(self._card.capabilities.streaming) + return False + + @property + def supports_push_notifications(self) -> bool: + """Check if the agent supports push notifications.""" + if self._card and self._card.capabilities: + return bool(self._card.capabilities.push_notifications) + return False + + def _build_message( + self, + text: str, + context_id: str | None = None, + task_id: str | None = None, + ) -> Message: + """Build a user message. + + Args: + text: Message content + context_id: Optional context ID for conversation continuity + task_id: Optional task ID to continue + + Returns: + Properly formatted Message object + """ + return Message( + message_id=str(uuid.uuid4()), + role=Role.user, + parts=[Part(root=TextPart(text=text))], + context_id=context_id, + task_id=task_id, + ) + + async def send( + self, + text: str, + context_id: str | None = None, + task_id: str | None = None, + ) -> SendResult: + """Send a message to the agent and wait for completion. + + This method collects all streaming events and returns the final result. + + Args: + text: Message to send + context_id: Optional context ID for conversation continuity + task_id: Optional task ID to continue + + Returns: + SendResult with task state, extracted text, and IDs + """ + client = await self._get_client() + message = self._build_message(text, context_id, task_id) + + log.info("Sending message: %s", text[:50] if len(text) > 50 else text) + + result = SendResult() + last_task: Task | None = None + + async for event in client.send_message(message): + if isinstance(event, Message): + result.message = event + result.context_id = event.context_id + result.task_id = event.task_id + result.text = _extract_text_from_parts(event.parts) + log.debug("Received message response") + elif isinstance(event, tuple): + task, update = event + last_task = task + result.task = task + result.task_id = task.id + result.context_id = task.context_id + if task.status: + result.state = task.status.state + log.debug( + "Received task update: %s", + task.status.state if task.status else "unknown", + ) + + if last_task: + result.text = _extract_text_from_task(last_task) + result.raw = ( + last_task.model_dump() if hasattr(last_task, "model_dump") else {} + ) + elif result.message: + result.raw = ( + result.message.model_dump() + if hasattr(result.message, "model_dump") + else {} + ) + + log.info("Send complete: task_id=%s, state=%s", result.task_id, result.state) + return result + + async def stream( + self, + text: str, + context_id: str | None = None, + task_id: str | None = None, + ) -> AsyncIterator[StreamEvent]: + """Send a message and stream responses as they arrive. + + Args: + text: Message to send + context_id: Optional context ID for conversation continuity + task_id: Optional task ID to continue + + Yields: + StreamEvent objects as they are received + """ + client = await self._get_client() + message = self._build_message(text, context_id, task_id) + + log.info("Streaming message: %s", text[:50] if len(text) > 50 else text) + + async for event in client.send_message(message): + if isinstance(event, Message): + yield StreamEvent( + event_type="message", + message=event, + context_id=event.context_id, + task_id=event.task_id, + text=_extract_text_from_parts(event.parts), + ) + elif isinstance(event, tuple): + task, update = event + if isinstance(update, TaskStatusUpdateEvent): + status_text = "" + if update.status and update.status.message: + status_text = str(update.status.message) + yield StreamEvent( + event_type="status", + task=task, + status=update, + context_id=task.context_id, + task_id=task.id, + state=update.status.state if update.status else None, + text=status_text, + ) + elif isinstance(update, TaskArtifactUpdateEvent): + artifact_text = "" + if update.artifact and update.artifact.parts: + artifact_text = _extract_text_from_parts(update.artifact.parts) + yield StreamEvent( + event_type="artifact", + task=task, + artifact=update, + context_id=task.context_id, + task_id=task.id, + state=task.status.state if task.status else None, + text=artifact_text, + ) + else: + yield StreamEvent( + event_type="task", + task=task, + context_id=task.context_id, + task_id=task.id, + state=task.status.state if task.status else None, + text=_extract_text_from_task(task), + ) + + async def get_task( + self, + task_id: str, + history_length: int | None = None, + ) -> TaskResult: + """Get the current state of a task. + + Args: + task_id: ID of the task to retrieve + history_length: Optional number of history messages to include + + Returns: + TaskResult with task state and details + """ + client = await self._get_client() + + params = TaskQueryParams(id=task_id, history_length=history_length) + log.info("Getting task: %s", task_id) + + task = await client.get_task(params) + + return TaskResult( + task=task, + task_id=task.id, + state=task.status.state if task.status else TaskState.unknown, + context_id=task.context_id, + text=_extract_text_from_task(task), + raw=task.model_dump() if hasattr(task, "model_dump") else {}, + ) + + async def cancel_task(self, task_id: str) -> TaskResult: + """Cancel a running task. + + Args: + task_id: ID of the task to cancel + + Returns: + TaskResult with updated task state + """ + client = await self._get_client() + + params = TaskIdParams(id=task_id) + log.info("Canceling task: %s", task_id) + + task = await client.cancel_task(params) + + return TaskResult( + task=task, + task_id=task.id, + state=task.status.state if task.status else TaskState.unknown, + context_id=task.context_id, + text=_extract_text_from_task(task), + raw=task.model_dump() if hasattr(task, "model_dump") else {}, + ) + + async def resubscribe(self, task_id: str) -> AsyncIterator[StreamEvent]: + """Resubscribe to a task's event stream. + + Args: + task_id: ID of the task to resubscribe to + + Yields: + StreamEvent objects as they are received + """ + client = await self._get_client() + + params = TaskIdParams(id=task_id) + log.info("Resubscribing to task: %s", task_id) + + async for event in client.resubscribe(params): + task, update = event + if isinstance(update, TaskStatusUpdateEvent): + yield StreamEvent( + event_type="status", + task=task, + status=update, + context_id=task.context_id, + task_id=task.id, + state=update.status.state if update.status else None, + ) + elif isinstance(update, TaskArtifactUpdateEvent): + artifact_text = "" + if update.artifact and update.artifact.parts: + artifact_text = _extract_text_from_parts(update.artifact.parts) + yield StreamEvent( + event_type="artifact", + task=task, + artifact=update, + context_id=task.context_id, + task_id=task.id, + state=task.status.state if task.status else None, + text=artifact_text, + ) + else: + yield StreamEvent( + event_type="task", + task=task, + context_id=task.context_id, + task_id=task.id, + state=task.status.state if task.status else None, + text=_extract_text_from_task(task), + ) + + async def set_push_config( + self, + task_id: str, + url: str, + token: str | None = None, + ) -> TaskPushNotificationConfig: + """Set push notification configuration for a task. + + Args: + task_id: ID of the task + url: Webhook URL to receive notifications + token: Optional authentication token + + Returns: + The created push notification configuration + """ + client = await self._get_client() + + from a2a.types import PushNotificationConfig + + config = TaskPushNotificationConfig( + task_id=task_id, + push_notification_config=PushNotificationConfig( + url=url, + token=token, + ), + ) + log.info("Setting push config for task %s: %s", task_id, url) + + return await client.set_task_callback(config) + + async def get_push_config( + self, + task_id: str, + config_id: str, + ) -> TaskPushNotificationConfig: + """Get push notification configuration for a task. + + Args: + task_id: ID of the task + config_id: ID of the push config + + Returns: + The push notification configuration + """ + client = await self._get_client() + + params = GetTaskPushNotificationConfigParams( + id=task_id, + push_notification_config_id=config_id, + ) + log.info("Getting push config %s for task %s", config_id, task_id) + + return await client.get_task_callback(params) diff --git a/src/a2a_handler/cli.py b/src/a2a_handler/cli.py index 1b3facc..ac8d5bc 100644 --- a/src/a2a_handler/cli.py +++ b/src/a2a_handler/cli.py @@ -35,8 +35,8 @@ ], "handler send": [ { - "name": "Conversation Options", - "options": ["--context-id", "--task-id"], + "name": "Message Options", + "options": ["--stream", "--continue", "--context-id", "--task-id"], }, { "name": "Output Options", @@ -49,20 +49,44 @@ "options": ["--host", "--port", "--help"], }, ], + "handler tasks get": [ + { + "name": "Query Options", + "options": ["--history-length", "--output", "--help"], + }, + ], } click.rich_click.COMMAND_GROUPS = { "handler": [ { "name": "Agent Commands", - "commands": ["card", "send", "validate"], + "commands": ["card", "send", "validate", "tasks", "push"], }, { "name": "Interface Commands", - "commands": ["tui", "server"], + "commands": ["tui", "server", "webhook"], }, { "name": "Utility Commands", - "commands": ["version"], + "commands": ["version", "session"], + }, + ], + "handler tasks": [ + { + "name": "Task Commands", + "commands": ["get", "cancel", "resubscribe"], + }, + ], + "handler push": [ + { + "name": "Push Notification Commands", + "commands": ["set", "get"], + }, + ], + "handler session": [ + { + "name": "Session Commands", + "commands": ["list", "show", "clear"], }, ], } @@ -75,13 +99,19 @@ A2AClientTimeoutError, ) +from a2a_handler.a2a_service import A2AService, SendResult, TaskResult # noqa: E402 from a2a_handler.client import ( # noqa: E402 build_http_client, fetch_agent_card, - parse_response, - send_message_to_agent, ) +from a2a_handler.push_server import run_webhook_server # noqa: E402 from a2a_handler.server import run_server # noqa: E402 +from a2a_handler.session import ( # noqa: E402 + clear_session, + get_session, + get_session_store, + update_session, +) from a2a_handler.tui import HandlerTUI # noqa: E402 from a2a_handler.validation import ( # noqa: E402 ValidationResult, @@ -92,6 +122,39 @@ log = get_logger(__name__) +def _handle_client_error(e: Exception, agent_url: str) -> None: + """Handle A2A client errors with appropriate messages.""" + if isinstance(e, A2AClientTimeoutError): + log.error("Request to %s timed out", agent_url) + print_error("Request timed out") + elif isinstance(e, A2AClientHTTPError): + log.error("A2A client error: %s", e) + if "connection" in str(e).lower(): + print_error(f"Connection failed: Is the server running at {agent_url}?") + else: + print_error(str(e)) + elif isinstance(e, A2AClientError): + log.error("A2A client error: %s", e) + print_error(str(e)) + elif isinstance(e, httpx.ConnectError): + log.error("Connection refused to %s", agent_url) + print_error(f"Connection refused: Is the server running at {agent_url}?") + elif isinstance(e, httpx.TimeoutException): + log.error("Request to %s timed out", agent_url) + print_error("Request timed out") + elif isinstance(e, httpx.HTTPStatusError): + log.error( + "HTTP error %d from %s: %s", + e.response.status_code, + agent_url, + e.response.text, + ) + print_error(f"HTTP {e.response.status_code} - {e.response.text}") + else: + log.exception("Failed request to %s", agent_url) + print_error(str(e)) + + @click.group() @click.option("--verbose", "-v", is_flag=True, help="Enable verbose logging output") @click.option("--debug", "-d", is_flag=True, help="Enable debug logging output") @@ -185,6 +248,60 @@ def _format_value(value: Any, indent: int = 0) -> str: return str(value) if value else "" +def _format_send_result(result: SendResult, output: str) -> None: + """Format and display a send result.""" + if output == "json": + print_json(json.dumps(result.raw, indent=2)) + return + + content_parts = [] + + if result.context_id: + content_parts.append(f"[bold]Context ID:[/bold] [dim]{result.context_id}[/dim]") + if result.task_id: + content_parts.append(f"[bold]Task ID:[/bold] [dim]{result.task_id}[/dim]") + if result.state: + state_color = "green" if result.is_complete else "yellow" + content_parts.append( + f"[bold]State:[/bold] [{state_color}]{result.state.value}[/{state_color}]" + ) + + if content_parts: + console.print("\n".join(content_parts)) + console.print() + + if result.text: + print_markdown(result.text, title="Response") + else: + console.print("[dim]No text content in response[/dim]") + + +def _format_task_result(result: TaskResult, output: str) -> None: + """Format and display a task result.""" + if output == "json": + print_json(json.dumps(result.raw, indent=2)) + return + + state_color = "green" if result.state.value in ("completed",) else "yellow" + if result.state.value in ("failed", "rejected", "canceled"): + state_color = "red" + + content_parts = [ + f"[bold]Task ID:[/bold] [dim]{result.task_id}[/dim]", + f"[bold]State:[/bold] [{state_color}]{result.state.value}[/{state_color}]", + ] + + if result.context_id: + content_parts.append(f"[bold]Context ID:[/bold] [dim]{result.context_id}[/dim]") + + title = f"[bold]Task {result.task_id[:8]}...[/bold]" + print_panel("\n".join(content_parts), title=title) + + if result.text: + console.print() + print_markdown(result.text, title="Content") + + @cli.command() @click.argument("agent_url") @click.option( @@ -239,41 +356,8 @@ async def fetch() -> None: print_panel("\n\n".join(content_parts), title=title) - except A2AClientTimeoutError: - log.error("Request to %s timed out", agent_url) - print_error("Request timed out") - raise click.Abort() - except A2AClientHTTPError as e: - log.error("A2A client error: %s", e) - if "connection" in str(e).lower(): - print_error(f"Connection failed: Is the server running at {agent_url}?") - else: - print_error(str(e)) - raise click.Abort() - except A2AClientError as e: - log.error("A2A client error: %s", e) - print_error(str(e)) - raise click.Abort() - except httpx.ConnectError: - log.error("Connection refused to %s", agent_url) - print_error(f"Connection refused: Is the server running at {agent_url}?") - raise click.Abort() - except httpx.TimeoutException: - log.error("Request to %s timed out", agent_url) - print_error("Request timed out") - raise click.Abort() - except httpx.HTTPStatusError as e: - log.error( - "HTTP error %d from %s: %s", - e.response.status_code, - agent_url, - e.response.text, - ) - print_error(f"HTTP {e.response.status_code} - {e.response.text}") - raise click.Abort() except Exception as e: - log.exception("Failed to fetch agent card from %s", agent_url) - print_error(str(e)) + _handle_client_error(e, agent_url) raise click.Abort() asyncio.run(fetch()) @@ -282,8 +366,6 @@ async def fetch() -> None: def _format_validation_result(result: ValidationResult, output: str) -> None: """Format and print validation result.""" if output == "json": - import json - output_data = { "valid": result.valid, "source": result.source, @@ -377,8 +459,16 @@ async def do_validate() -> None: @cli.command() @click.argument("agent_url") @click.argument("message") +@click.option("--stream", "-s", is_flag=True, help="Stream responses in real-time") @click.option("--context-id", help="Context ID for conversation continuity") @click.option("--task-id", help="Reference an existing task ID") +@click.option( + "--continue", + "-c", + "use_session", + is_flag=True, + help="Continue last conversation (use saved context_id)", +) @click.option( "--output", "-o", @@ -389,14 +479,26 @@ async def do_validate() -> None: def send( agent_url: str, message: str, + stream: bool, context_id: Optional[str], task_id: Optional[str], + use_session: bool, output: str, ) -> None: - """Send MESSAGE to an agent at AGENT_URL.""" + """Send MESSAGE to an agent at AGENT_URL. + + Use --stream to receive responses in real-time via Server-Sent Events. + Use --continue to automatically use the last context_id from previous conversation. + """ log.info("Sending message to %s", agent_url) log.debug("Message: %s", message[:100] if len(message) > 100 else message) + if use_session and not context_id: + session = get_session(agent_url) + if session.context_id: + context_id = session.context_id + log.info("Using saved context ID: %s", context_id) + if context_id: log.debug("Using context ID: %s", context_id) if task_id: @@ -404,69 +506,389 @@ def send( async def send_msg() -> None: try: - log.debug("Building HTTP client") - async with build_http_client() as client: + async with build_http_client() as http_client: + service = A2AService(http_client, agent_url, streaming=stream) + if output == "text": console.print(f"[dim]Sending message to {agent_url}...[/dim]") - log.debug("Sending message via A2A client") - response = await send_message_to_agent( - agent_url, message, client, context_id, task_id - ) - log.debug("Received response from agent") + if stream: + log.debug("Using streaming mode") + collected_text: list[str] = [] + last_context_id: str | None = None + last_task_id: str | None = None + last_state = None + + async for event in service.stream(message, context_id, task_id): + last_context_id = event.context_id or last_context_id + last_task_id = event.task_id or last_task_id + last_state = event.state or last_state + + if output == "json": + event_data = { + "type": event.event_type, + "context_id": event.context_id, + "task_id": event.task_id, + "state": event.state.value if event.state else None, + "text": event.text, + } + print_json(json.dumps(event_data)) + else: + if event.text and event.text not in collected_text: + console.print(event.text, end="", markup=False) + collected_text.append(event.text) + + update_session(agent_url, last_context_id, last_task_id) + + if output == "text": + console.print() + console.print() + info_parts = [] + if last_context_id: + info_parts.append( + f"[bold]Context ID:[/bold] [dim]{last_context_id}[/dim]" + ) + if last_task_id: + info_parts.append( + f"[bold]Task ID:[/bold] [dim]{last_task_id}[/dim]" + ) + if last_state: + info_parts.append(f"[bold]State:[/bold] {last_state.value}") + if info_parts: + console.print("\n".join(info_parts)) - if output == "json": - log.debug("Outputting response as JSON") - print_json(json.dumps(response, indent=2)) else: - log.debug("Parsing response for text output") - parsed = parse_response(response) - - if parsed.has_content: - log.debug("Response contains %d characters", len(parsed.text)) - print_markdown(parsed.text, title="Response") - else: - log.warning("Response contained no text content") - print_markdown("No text in response", title="Response") + log.debug("Using non-streaming mode") + result = await service.send(message, context_id, task_id) + update_session(agent_url, result.context_id, result.task_id) + _format_send_result(result, output) - except A2AClientTimeoutError: - log.error("Request to %s timed out", agent_url) - print_error("Request timed out") - raise click.Abort() - except A2AClientHTTPError as e: - log.error("A2A client error: %s", e) - if "connection" in str(e).lower(): - print_error(f"Connection failed: Is the server running at {agent_url}?") - else: - print_error(str(e)) + except Exception as e: + _handle_client_error(e, agent_url) raise click.Abort() - except A2AClientError as e: - log.error("A2A client error: %s", e) - print_error(str(e)) + + asyncio.run(send_msg()) + + +@cli.group() +def tasks() -> None: + """Manage A2A tasks.""" + pass + + +@tasks.command("get") +@click.argument("agent_url") +@click.argument("task_id") +@click.option( + "--history-length", + "-n", + type=int, + default=None, + help="Number of history messages to include", +) +@click.option( + "--output", + "-o", + type=click.Choice(["json", "text"]), + default="text", + help="Output format", +) +def tasks_get( + agent_url: str, + task_id: str, + history_length: Optional[int], + output: str, +) -> None: + """Get the status of a task by TASK_ID.""" + log.info("Getting task %s from %s", task_id, agent_url) + + async def get_task() -> None: + try: + async with build_http_client() as http_client: + service = A2AService(http_client, agent_url) + result = await service.get_task(task_id, history_length) + _format_task_result(result, output) + + except Exception as e: + _handle_client_error(e, agent_url) raise click.Abort() - except httpx.ConnectError: - log.error("Connection refused to %s", agent_url) - print_error(f"Connection refused: Is the server running at {agent_url}?") + + asyncio.run(get_task()) + + +@tasks.command("cancel") +@click.argument("agent_url") +@click.argument("task_id") +@click.option( + "--output", + "-o", + type=click.Choice(["json", "text"]), + default="text", + help="Output format", +) +def tasks_cancel( + agent_url: str, + task_id: str, + output: str, +) -> None: + """Cancel a running task by TASK_ID.""" + log.info("Canceling task %s at %s", task_id, agent_url) + + async def cancel_task() -> None: + try: + async with build_http_client() as http_client: + service = A2AService(http_client, agent_url) + + if output == "text": + console.print(f"[dim]Canceling task {task_id}...[/dim]") + + result = await service.cancel_task(task_id) + _format_task_result(result, output) + + if output == "text": + console.print("[green]Task canceled successfully[/green]") + + except Exception as e: + _handle_client_error(e, agent_url) raise click.Abort() - except httpx.TimeoutException: - log.error("Request to %s timed out", agent_url) - print_error("Request timed out") + + asyncio.run(cancel_task()) + + +@tasks.command("resubscribe") +@click.argument("agent_url") +@click.argument("task_id") +@click.option( + "--output", + "-o", + type=click.Choice(["json", "text"]), + default="text", + help="Output format", +) +def tasks_resubscribe( + agent_url: str, + task_id: str, + output: str, +) -> None: + """Resubscribe to a task's event stream by TASK_ID. + + This resumes streaming for a task that you previously disconnected from. + """ + log.info("Resubscribing to task %s at %s", task_id, agent_url) + + async def resubscribe() -> None: + try: + async with build_http_client() as http_client: + service = A2AService(http_client, agent_url) + + if output == "text": + console.print(f"[dim]Resubscribing to task {task_id}...[/dim]") + + async for event in service.resubscribe(task_id): + if output == "json": + event_data = { + "type": event.event_type, + "context_id": event.context_id, + "task_id": event.task_id, + "state": event.state.value if event.state else None, + "text": event.text, + } + print_json(json.dumps(event_data)) + else: + if event.event_type == "status": + console.print( + f"[dim]Status:[/dim] {event.state.value if event.state else 'unknown'}" + ) + elif event.text: + console.print(event.text, markup=False) + + except Exception as e: + _handle_client_error(e, agent_url) raise click.Abort() - except httpx.HTTPStatusError as e: - log.error( - "HTTP error %d from %s: %s", - e.response.status_code, - agent_url, - e.response.text, - ) - print_error(f"HTTP {e.response.status_code} - {e.response.text}") + + asyncio.run(resubscribe()) + + +@cli.group() +def push() -> None: + """Manage push notification configurations.""" + pass + + +@push.command("set") +@click.argument("agent_url") +@click.argument("task_id") +@click.option("--url", "-u", required=True, help="Webhook URL to receive notifications") +@click.option("--token", "-t", help="Optional authentication token") +@click.option( + "--output", + "-o", + type=click.Choice(["json", "text"]), + default="text", + help="Output format", +) +def push_set( + agent_url: str, + task_id: str, + url: str, + token: Optional[str], + output: str, +) -> None: + """Set push notification config for a task. + + Configure the agent to send push notifications to a webhook URL + when task status changes. + + Example: + handler push set http://localhost:8000 TASK_ID --url http://localhost:9000/webhook + """ + log.info("Setting push config for task %s at %s", task_id, agent_url) + + async def set_push() -> None: + try: + async with build_http_client() as http_client: + service = A2AService(http_client, agent_url) + + if output == "text": + console.print( + f"[dim]Setting push notification config for task {task_id}...[/dim]" + ) + + config = await service.set_push_config(task_id, url, token) + + if output == "json": + print_json(config.model_dump_json(indent=2)) + else: + console.print( + "[green]Push notification config set successfully[/green]" + ) + console.print(f"[bold]Task ID:[/bold] {task_id}") + console.print(f"[bold]Webhook URL:[/bold] {url}") + if token: + console.print(f"[bold]Token:[/bold] {token[:20]}...") + + except Exception as e: + _handle_client_error(e, agent_url) raise click.Abort() + + asyncio.run(set_push()) + + +@push.command("get") +@click.argument("agent_url") +@click.argument("task_id") +@click.argument("config_id") +@click.option( + "--output", + "-o", + type=click.Choice(["json", "text"]), + default="text", + help="Output format", +) +def push_get( + agent_url: str, + task_id: str, + config_id: str, + output: str, +) -> None: + """Get push notification config for a task.""" + log.info("Getting push config %s for task %s at %s", config_id, task_id, agent_url) + + async def get_push() -> None: + try: + async with build_http_client() as http_client: + service = A2AService(http_client, agent_url) + config = await service.get_push_config(task_id, config_id) + + if output == "json": + print_json(config.model_dump_json(indent=2)) + else: + console.print("[bold]Push Notification Config[/bold]") + console.print(f"[bold]Task ID:[/bold] {config.task_id}") + if config.push_notification_config: + pnc = config.push_notification_config + console.print(f"[bold]URL:[/bold] {pnc.url}") + if pnc.token: + console.print(f"[bold]Token:[/bold] {pnc.token[:20]}...") + except Exception as e: - log.exception("Failed to send message to %s", agent_url) - print_error(str(e)) + _handle_client_error(e, agent_url) raise click.Abort() - asyncio.run(send_msg()) + asyncio.run(get_push()) + + +@cli.command() +@click.option("--host", default="127.0.0.1", help="Host to bind to", show_default=True) +@click.option("--port", default=9000, help="Port to bind to", show_default=True) +def webhook(host: str, port: int) -> None: + """Start a local webhook server to receive push notifications. + + This starts a simple HTTP server that receives and displays + push notifications from A2A agents. Useful for testing. + + Example: + handler webhook --port 9000 + # Then use http://localhost:9000/webhook as your push notification URL + """ + log.info("Starting webhook server on %s:%d", host, port) + run_webhook_server(host, port) + + +@cli.group() +def session() -> None: + """Manage saved session state.""" + pass + + +@session.command("list") +def session_list() -> None: + """List all saved sessions.""" + store = get_session_store() + sessions = store.list_all() + + if not sessions: + console.print("[dim]No saved sessions[/dim]") + return + + console.print(f"[bold]Saved Sessions ({len(sessions)}):[/bold]") + console.print() + for s in sessions: + console.print(f"[bold cyan]{s.agent_url}[/bold cyan]") + if s.context_id: + console.print(f" [dim]Context ID:[/dim] {s.context_id}") + if s.task_id: + console.print(f" [dim]Task ID:[/dim] {s.task_id}") + + +@session.command("show") +@click.argument("agent_url") +def session_show(agent_url: str) -> None: + """Show session for a specific agent.""" + s = get_session(agent_url) + console.print(f"[bold]Session for {agent_url}[/bold]") + console.print(f"[bold]Context ID:[/bold] {s.context_id or '[dim]none[/dim]'}") + console.print(f"[bold]Task ID:[/bold] {s.task_id or '[dim]none[/dim]'}") + + +@session.command("clear") +@click.argument("agent_url", required=False) +@click.option("--all", "-a", "clear_all", is_flag=True, help="Clear all sessions") +def session_clear(agent_url: Optional[str], clear_all: bool) -> None: + """Clear saved session(s). + + Provide AGENT_URL to clear a specific session, or use --all to clear all. + """ + if clear_all: + clear_session() + console.print("[green]Cleared all sessions[/green]") + elif agent_url: + clear_session(agent_url) + console.print(f"[green]Cleared session for {agent_url}[/green]") + else: + console.print( + "[yellow]Provide AGENT_URL or use --all to clear sessions[/yellow]" + ) @cli.command() diff --git a/src/a2a_handler/client.py b/src/a2a_handler/client.py index 4fa9768..9786a54 100644 --- a/src/a2a_handler/client.py +++ b/src/a2a_handler/client.py @@ -1,20 +1,16 @@ -"""A2A protocol client utilities.""" +"""A2A protocol client utilities. + +This module provides backwards-compatible functions that wrap A2AService. +For new code, use A2AService directly. +""" -import uuid from dataclasses import dataclass from typing import Any import httpx -from a2a.client import A2ACardResolver, ClientConfig, ClientFactory -from a2a.types import ( - AgentCard, - Message, - Part, - Role, - TextPart, - TransportProtocol, -) +from a2a.types import AgentCard +from a2a_handler.a2a_service import A2AService from a2a_handler.common import get_logger log = get_logger(__name__) @@ -47,35 +43,8 @@ async def fetch_agent_card(agent_url: str, client: httpx.AsyncClient) -> AgentCa Raises: httpx.RequestError: If the request fails """ - log.info("Fetching agent card from [url]%s[/url]", agent_url) - resolver = A2ACardResolver(client, agent_url) - card = await resolver.get_agent_card() - log.info("Received card for [agent]%s[/agent]", card.name) - return card - - -def _build_message( - message_text: str, - context_id: str | None = None, - task_id: str | None = None, -) -> Message: - """Build a message object. - - Args: - message_text: The message content - context_id: Optional context ID for conversation continuity - task_id: Optional task ID to reference - - Returns: - A properly formatted message - """ - return Message( - message_id=str(uuid.uuid4()), - role=Role.user, - parts=[Part(TextPart(text=message_text))], - context_id=context_id, - task_id=task_id, - ) + service = A2AService(client, agent_url) + return await service.get_card() async def send_message_to_agent( @@ -101,33 +70,9 @@ async def send_message_to_agent( httpx.RequestError: If the request fails httpx.TimeoutException: If the request times out """ - log.info("Sending message to [url]%s[/url]: %s", agent_url, message_text[:50]) - card = await fetch_agent_card(agent_url, client) - log.debug("Connected to [agent]%s[/agent]", card.name) - - config = ClientConfig( - httpx_client=client, supported_transports=[TransportProtocol.jsonrpc] - ) - factory = ClientFactory(config) - a2a_client = factory.create(card) - - message = _build_message(message_text, context_id, task_id) - - log.debug("Sending request with ID: %s", message.message_id) - - last_response = None - async for response in a2a_client.send_message(message): - last_response = response - - log.debug("Received response") - - if last_response is None: - return {} - - if isinstance(last_response, tuple): - return last_response[0].model_dump() - - return last_response.model_dump() if hasattr(last_response, "model_dump") else {} + service = A2AService(client, agent_url) + result = await service.send(message_text, context_id, task_id) + return result.raw @dataclass @@ -136,6 +81,8 @@ class ParsedResponse: text: str raw: dict[str, Any] + context_id: str | None = None + task_id: str | None = None @property def has_content(self) -> bool: @@ -157,17 +104,44 @@ def parse_response(response: dict[str, Any]) -> ParsedResponse: return ParsedResponse(text="", raw=response) texts: list[str] = [] + context_id = response.get("context_id") + task_id = response.get("id") or response.get("task_id") if "parts" in response: - texts.extend(p.get("text", "") for p in response["parts"]) + for p in response["parts"]: + if isinstance(p, dict): + if "root" in p and isinstance(p["root"], dict): + texts.append(p["root"].get("text", "")) + else: + texts.append(p.get("text", "")) log.debug("Extracted %d parts from response", len(response["parts"])) for artifact in response.get("artifacts", []): artifact_parts = artifact.get("parts", []) - texts.extend(p.get("text", "") for p in artifact_parts) + for p in artifact_parts: + if isinstance(p, dict): + if "root" in p and isinstance(p["root"], dict): + texts.append(p["root"].get("text", "")) + else: + texts.append(p.get("text", "")) log.debug("Extracted %d parts from artifact", len(artifact_parts)) + if "history" in response: + for msg in response["history"]: + if msg.get("role") == "agent": + for p in msg.get("parts", []): + if isinstance(p, dict): + if "root" in p and isinstance(p["root"], dict): + texts.append(p["root"].get("text", "")) + else: + texts.append(p.get("text", "")) + text = "\n".join(t for t in texts if t) log.debug("Parsed response with %d characters", len(text)) - return ParsedResponse(text=text, raw=response) + return ParsedResponse( + text=text, + raw=response, + context_id=context_id, + task_id=task_id, + ) diff --git a/src/a2a_handler/push_server.py b/src/a2a_handler/push_server.py new file mode 100644 index 0000000..e80fe55 --- /dev/null +++ b/src/a2a_handler/push_server.py @@ -0,0 +1,166 @@ +"""Local webhook server for receiving A2A push notifications. + +This module provides a simple HTTP server that can receive push notifications +from A2A agents for testing purposes. +""" + +import json +from collections import deque +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +import uvicorn +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.routing import Route + +from a2a_handler.common import console, get_logger + +log = get_logger(__name__) + +MAX_NOTIFICATIONS = 100 + + +@dataclass +class Notification: + """A received push notification.""" + + timestamp: datetime + task_id: str | None + payload: dict[str, Any] + headers: dict[str, str] + + +@dataclass +class NotificationStore: + """In-memory store for received notifications.""" + + notifications: deque[Notification] = field( + default_factory=lambda: deque(maxlen=MAX_NOTIFICATIONS) + ) + + def add(self, notification: Notification) -> None: + """Add a notification to the store.""" + self.notifications.append(notification) + + def get_all(self) -> list[Notification]: + """Get all stored notifications.""" + return list(self.notifications) + + def clear(self) -> None: + """Clear all stored notifications.""" + self.notifications.clear() + + +notification_store = NotificationStore() + + +async def handle_notification(request: Request) -> JSONResponse: + """Handle incoming push notifications from A2A agents.""" + try: + payload = await request.json() + except json.JSONDecodeError: + log.warning("Received invalid JSON in push notification") + return JSONResponse({"error": "Invalid JSON"}, status_code=400) + + headers = dict(request.headers) + task_id = payload.get("id") or payload.get("task_id") + + notification = Notification( + timestamp=datetime.now(), + task_id=task_id, + payload=payload, + headers=headers, + ) + notification_store.add(notification) + + log.info("Received push notification for task: %s", task_id) + + console.print("\n[bold cyan]Push Notification Received[/bold cyan]") + console.print(f"[dim]Timestamp:[/dim] {notification.timestamp.isoformat()}") + if task_id: + console.print(f"[dim]Task ID:[/dim] {task_id}") + + status = payload.get("status", {}) + if status: + state = status.get("state", "unknown") + console.print(f"[dim]State:[/dim] {state}") + + token = headers.get("x-a2a-notification-token") + if token: + console.print(f"[dim]Token:[/dim] {token[:20]}...") + + console.print() + console.print_json(json.dumps(payload, indent=2, default=str)) + console.print() + + return JSONResponse({"status": "ok", "received": True}) + + +async def handle_validation(request: Request) -> JSONResponse: + """Handle GET requests for webhook validation.""" + log.info("Webhook validation request received") + return JSONResponse({"status": "ok", "message": "Webhook is active"}) + + +async def handle_list(request: Request) -> JSONResponse: + """List all received notifications.""" + notifications = notification_store.get_all() + return JSONResponse( + { + "count": len(notifications), + "notifications": [ + { + "timestamp": n.timestamp.isoformat(), + "task_id": n.task_id, + "payload": n.payload, + } + for n in notifications + ], + } + ) + + +async def handle_clear(request: Request) -> JSONResponse: + """Clear all stored notifications.""" + notification_store.clear() + log.info("Cleared all stored notifications") + return JSONResponse({"status": "ok", "message": "Notifications cleared"}) + + +def create_webhook_app() -> Starlette: + """Create the webhook Starlette application.""" + routes = [ + Route("/webhook", handle_notification, methods=["POST"]), + Route("/webhook", handle_validation, methods=["GET"]), + Route("/notifications", handle_list, methods=["GET"]), + Route("/notifications/clear", handle_clear, methods=["POST"]), + ] + return Starlette(routes=routes) + + +def run_webhook_server(host: str = "127.0.0.1", port: int = 9000) -> None: + """Start the webhook server. + + Args: + host: Host address to bind to + port: Port number to bind to + """ + console.print(f"\n[bold]Starting webhook server on [url]{host}:{port}[/url][/bold]") + console.print() + console.print("[dim]Endpoints:[/dim]") + console.print(f" POST http://{host}:{port}/webhook - Receive notifications") + console.print(f" GET http://{host}:{port}/webhook - Validation check") + console.print(f" GET http://{host}:{port}/notifications - List received") + console.print(f" POST http://{host}:{port}/notifications/clear - Clear stored") + console.print() + console.print( + f"[bold green]Use this URL for push notifications:[/bold green] " + f"http://{host}:{port}/webhook" + ) + console.print() + + app = create_webhook_app() + uvicorn.run(app, host=host, port=port, log_level="warning") diff --git a/src/a2a_handler/session.py b/src/a2a_handler/session.py new file mode 100644 index 0000000..ce89967 --- /dev/null +++ b/src/a2a_handler/session.py @@ -0,0 +1,165 @@ +"""Session state management for A2A CLI. + +Provides persistence of context_id and task_id across CLI invocations. +""" + +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from a2a_handler.common import get_logger + +log = get_logger(__name__) + +DEFAULT_SESSION_DIR = Path.home() / ".handler" +SESSION_FILE = "sessions.json" + + +@dataclass +class AgentSession: + """Session state for a single agent.""" + + agent_url: str + context_id: str | None = None + task_id: str | None = None + + def update( + self, + context_id: str | None = None, + task_id: str | None = None, + ) -> None: + """Update session with new values (only if provided).""" + if context_id is not None: + self.context_id = context_id + if task_id is not None: + self.task_id = task_id + + +@dataclass +class SessionStore: + """Persistent store for agent sessions.""" + + sessions: dict[str, AgentSession] = field(default_factory=dict) + session_dir: Path = field(default_factory=lambda: DEFAULT_SESSION_DIR) + + @property + def session_file(self) -> Path: + """Path to the session file.""" + return self.session_dir / SESSION_FILE + + def _ensure_dir(self) -> None: + """Ensure the session directory exists.""" + self.session_dir.mkdir(parents=True, exist_ok=True) + + def load(self) -> None: + """Load sessions from disk.""" + if not self.session_file.exists(): + log.debug("No session file found at %s", self.session_file) + return + + try: + with open(self.session_file) as f: + data = json.load(f) + + for url, session_data in data.items(): + self.sessions[url] = AgentSession( + agent_url=url, + context_id=session_data.get("context_id"), + task_id=session_data.get("task_id"), + ) + log.debug( + "Loaded %d sessions from %s", len(self.sessions), self.session_file + ) + + except json.JSONDecodeError as e: + log.warning("Failed to parse session file: %s", e) + except OSError as e: + log.warning("Failed to read session file: %s", e) + + def save(self) -> None: + """Save sessions to disk.""" + self._ensure_dir() + + data: dict[str, Any] = {} + for url, session in self.sessions.items(): + data[url] = { + "context_id": session.context_id, + "task_id": session.task_id, + } + + try: + with open(self.session_file, "w") as f: + json.dump(data, f, indent=2) + log.debug("Saved %d sessions to %s", len(self.sessions), self.session_file) + except OSError as e: + log.warning("Failed to write session file: %s", e) + + def get(self, agent_url: str) -> AgentSession: + """Get or create a session for an agent URL.""" + if agent_url not in self.sessions: + self.sessions[agent_url] = AgentSession(agent_url=agent_url) + return self.sessions[agent_url] + + def update( + self, + agent_url: str, + context_id: str | None = None, + task_id: str | None = None, + ) -> AgentSession: + """Update session for an agent and save.""" + session = self.get(agent_url) + session.update(context_id, task_id) + self.save() + return session + + def clear(self, agent_url: str | None = None) -> None: + """Clear session(s). + + Args: + agent_url: If provided, clear only that agent's session. + Otherwise, clear all sessions. + """ + if agent_url: + if agent_url in self.sessions: + del self.sessions[agent_url] + log.info("Cleared session for %s", agent_url) + else: + self.sessions.clear() + log.info("Cleared all sessions") + self.save() + + def list_all(self) -> list[AgentSession]: + """List all sessions.""" + return list(self.sessions.values()) + + +_store: SessionStore | None = None + + +def get_session_store() -> SessionStore: + """Get the global session store (singleton).""" + global _store + if _store is None: + _store = SessionStore() + _store.load() + return _store + + +def get_session(agent_url: str) -> AgentSession: + """Get session for an agent URL.""" + return get_session_store().get(agent_url) + + +def update_session( + agent_url: str, + context_id: str | None = None, + task_id: str | None = None, +) -> AgentSession: + """Update and persist session for an agent.""" + return get_session_store().update(agent_url, context_id, task_id) + + +def clear_session(agent_url: str | None = None) -> None: + """Clear session(s).""" + get_session_store().clear(agent_url) From 1a68a8757805229ae7281c9d89fa6ced25e745d7 Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Sat, 6 Dec 2025 17:20:13 -0500 Subject: [PATCH 02/23] feat: build and pass custom AgentCard to A2A server --- src/a2a_handler/server.py | 51 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/src/a2a_handler/server.py b/src/a2a_handler/server.py index 6e6e8ee..6cfbbbb 100644 --- a/src/a2a_handler/server.py +++ b/src/a2a_handler/server.py @@ -4,6 +4,7 @@ import click import uvicorn +from a2a.types import AgentCapabilities, AgentCard, AgentSkill from dotenv import load_dotenv from google.adk.a2a.utils.agent_to_a2a import to_a2a from google.adk.agents.llm_agent import Agent @@ -61,6 +62,44 @@ def create_agent() -> Agent: return agent +def build_agent_card(agent: Agent, host: str, port: int) -> AgentCard: + """Build an AgentCard with streaming and push notification capabilities. + + Args: + agent: The ADK agent + host: Host address for the RPC URL + port: Port number for the RPC URL + + Returns: + Configured AgentCard with capabilities enabled + """ + capabilities = AgentCapabilities( + streaming=True, + push_notifications=True, + ) + + skill = AgentSkill( + id="handler_assistant", + name="Handler Assistant", + description="Answers questions about the Handler A2A toolkit and helps with usage", + tags=["a2a", "handler", "help"], + examples=["What is Handler?", "How do I use the CLI?", "Tell me about A2A"], + ) + + rpc_url = f"http://{host}:{port}/" + + return AgentCard( + name=agent.name, + description=agent.description or "Handler A2A agent", + url=rpc_url, + version="1.0.0", + capabilities=capabilities, + skills=[skill], + default_input_modes=["text/plain"], + default_output_modes=["text/plain"], + ) + + def run_server(host: str, port: int) -> None: """Start the A2A server agent. @@ -73,7 +112,17 @@ def run_server(host: str, port: int) -> None: ) log.info("Initializing A2A server...") agent = create_agent() - a2a_app = to_a2a(agent, host=host, port=port) + + agent_card = build_agent_card(agent, host, port) + log.info( + "Agent card capabilities: streaming=%s, push_notifications=%s", + agent_card.capabilities.streaming if agent_card.capabilities else False, + agent_card.capabilities.push_notifications + if agent_card.capabilities + else False, + ) + + a2a_app = to_a2a(agent, host=host, port=port, agent_card=agent_card) uvicorn.run(a2a_app, host=host, port=port) From 17d442d1f1a764cfed702aca7a8b3a54e02d18cd Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Sat, 6 Dec 2025 17:48:00 -0500 Subject: [PATCH 03/23] feat: add push notification support to A2A server --- src/a2a_handler/server.py | 99 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 95 insertions(+), 4 deletions(-) diff --git a/src/a2a_handler/server.py b/src/a2a_handler/server.py index 6cfbbbb..b8a7177 100644 --- a/src/a2a_handler/server.py +++ b/src/a2a_handler/server.py @@ -1,14 +1,31 @@ -"""Handler A2A server agent.""" +"""Handler A2A server agent with full push notification support.""" import os +from collections.abc import Awaitable, Callable import click +import httpx import uvicorn +from a2a.server.apps import A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import ( + InMemoryPushNotificationConfigStore, + InMemoryTaskStore, + BasePushNotificationSender, +) from a2a.types import AgentCapabilities, AgentCard, AgentSkill from dotenv import load_dotenv -from google.adk.a2a.utils.agent_to_a2a import to_a2a +from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor from google.adk.agents.llm_agent import Agent +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.auth.credential_service.in_memory_credential_service import ( + InMemoryCredentialService, +) +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService from google.adk.models.lite_llm import LiteLlm +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from starlette.applications import Starlette from a2a_handler.common import console, get_logger, setup_logging @@ -100,6 +117,80 @@ def build_agent_card(agent: Agent, host: str, port: int) -> AgentCard: ) +def create_runner_factory(agent: Agent) -> Callable[[], Awaitable[Runner]]: + """Create a factory function that builds a Runner for the agent. + + Args: + agent: The ADK agent to wrap + + Returns: + A callable that creates a Runner instance + """ + + async def create_runner() -> Runner: + return Runner( + app_name=agent.name or "handler_agent", + agent=agent, + artifact_service=InMemoryArtifactService(), + session_service=InMemorySessionService(), + memory_service=InMemoryMemoryService(), + credential_service=InMemoryCredentialService(), + ) + + return create_runner + + +def create_a2a_app(agent: Agent, agent_card: AgentCard) -> Starlette: + """Create a Starlette A2A application with full push notification support. + + This is a custom implementation that replaces google-adk's to_a2a() to add + push notification support. The to_a2a() function doesn't pass push_config_store + or push_sender to DefaultRequestHandler, causing push notification operations + to fail with "UnsupportedOperationError". + + Args: + agent: The ADK agent + agent_card: Pre-configured agent card + + Returns: + Configured Starlette application + """ + task_store = InMemoryTaskStore() + push_config_store = InMemoryPushNotificationConfigStore() + http_client = httpx.AsyncClient(timeout=30.0) + push_sender = BasePushNotificationSender(http_client, push_config_store) + + agent_executor = A2aAgentExecutor( + runner=create_runner_factory(agent), + ) + + request_handler = DefaultRequestHandler( + agent_executor=agent_executor, + task_store=task_store, + push_config_store=push_config_store, + push_sender=push_sender, + ) + + app = Starlette() + + async def setup_a2a() -> None: + a2a_app = A2AStarletteApplication( + agent_card=agent_card, + http_handler=request_handler, + ) + a2a_app.add_routes_to_app(app) + log.info("A2A routes configured with push notification support") + + async def cleanup() -> None: + await http_client.aclose() + log.info("HTTP client closed") + + app.add_event_handler("startup", setup_a2a) + app.add_event_handler("shutdown", cleanup) + + return app + + def run_server(host: str, port: int) -> None: """Start the A2A server agent. @@ -110,7 +201,7 @@ def run_server(host: str, port: int) -> None: console.print( f"\n[bold]Starting Handler server on [url]{host}:{port}[/url][/bold]\n" ) - log.info("Initializing A2A server...") + log.info("Initializing A2A server with push notification support...") agent = create_agent() agent_card = build_agent_card(agent, host, port) @@ -122,7 +213,7 @@ def run_server(host: str, port: int) -> None: else False, ) - a2a_app = to_a2a(agent, host=host, port=port, agent_card=agent_card) + a2a_app = create_a2a_app(agent, agent_card) uvicorn.run(a2a_app, host=host, port=port) From 941576e934c399ebc862b2db83bfbe97e67e1bcf Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Sat, 6 Dec 2025 18:01:43 -0500 Subject: [PATCH 04/23] wip: push notification support --- src/a2a_handler/a2a_service.py | 17 +++++++++++++++++ src/a2a_handler/cli.py | 26 +++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/src/a2a_handler/a2a_service.py b/src/a2a_handler/a2a_service.py index 3bae18c..b85e885 100644 --- a/src/a2a_handler/a2a_service.py +++ b/src/a2a_handler/a2a_service.py @@ -14,6 +14,7 @@ GetTaskPushNotificationConfigParams, Message, Part, + PushNotificationConfig, Role, Task, TaskArtifactUpdateEvent, @@ -127,6 +128,8 @@ def __init__( http_client: httpx.AsyncClient, agent_url: str, streaming: bool = True, + push_notification_url: str | None = None, + push_notification_token: str | None = None, ): """Initialize the A2A service. @@ -134,10 +137,14 @@ def __init__( http_client: Async HTTP client to use for requests agent_url: Base URL of the A2A agent streaming: Whether to prefer streaming when available + push_notification_url: Optional webhook URL for push notifications + push_notification_token: Optional token for push notification auth """ self.http_client = http_client self.agent_url = agent_url self.streaming = streaming + self.push_notification_url = push_notification_url + self.push_notification_token = push_notification_token self._client: Client | None = None self._card: AgentCard | None = None @@ -162,10 +169,20 @@ async def _get_client(self) -> Client: """ if self._client is None: card = await self.get_card() + push_configs: list[PushNotificationConfig] = [] + if self.push_notification_url: + push_configs.append( + PushNotificationConfig( + url=self.push_notification_url, + token=self.push_notification_token, + ) + ) + log.info("Push notification config: %s", self.push_notification_url) config = ClientConfig( httpx_client=self.http_client, supported_transports=[TransportProtocol.jsonrpc], streaming=self.streaming, + push_notification_configs=push_configs, ) factory = ClientFactory(config) self._client = factory.create(card) diff --git a/src/a2a_handler/cli.py b/src/a2a_handler/cli.py index ac8d5bc..3c1d8fd 100644 --- a/src/a2a_handler/cli.py +++ b/src/a2a_handler/cli.py @@ -38,6 +38,10 @@ "name": "Message Options", "options": ["--stream", "--continue", "--context-id", "--task-id"], }, + { + "name": "Push Notification Options", + "options": ["--push-url", "--push-token"], + }, { "name": "Output Options", "options": ["--output", "--help"], @@ -469,6 +473,15 @@ async def do_validate() -> None: is_flag=True, help="Continue last conversation (use saved context_id)", ) +@click.option( + "--push-url", + "-p", + help="Webhook URL to receive push notifications for this task", +) +@click.option( + "--push-token", + help="Optional authentication token for push notifications", +) @click.option( "--output", "-o", @@ -483,12 +496,15 @@ def send( context_id: Optional[str], task_id: Optional[str], use_session: bool, + push_url: Optional[str], + push_token: Optional[str], output: str, ) -> None: """Send MESSAGE to an agent at AGENT_URL. Use --stream to receive responses in real-time via Server-Sent Events. Use --continue to automatically use the last context_id from previous conversation. + Use --push-url to configure push notifications for task updates. """ log.info("Sending message to %s", agent_url) log.debug("Message: %s", message[:100] if len(message) > 100 else message) @@ -507,10 +523,18 @@ def send( async def send_msg() -> None: try: async with build_http_client() as http_client: - service = A2AService(http_client, agent_url, streaming=stream) + service = A2AService( + http_client, + agent_url, + streaming=stream, + push_notification_url=push_url, + push_notification_token=push_token, + ) if output == "text": console.print(f"[dim]Sending message to {agent_url}...[/dim]") + if push_url: + console.print(f"[dim]Push notifications: {push_url}[/dim]") if stream: log.debug("Using streaming mode") From a99dabd792f363e3654f77af52a4a66b72d13cb8 Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Sun, 7 Dec 2025 19:08:06 -0500 Subject: [PATCH 05/23] docs: higher level to avoid slippage --- AGENTS.md | 76 +++++++++++++------------------------------------ CONTRIBUTING.md | 65 +++++------------------------------------- README.md | 64 ++--------------------------------------- 3 files changed, 29 insertions(+), 176 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 1ce6b2d..3e772fc 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,70 +1,32 @@ # Agent Development Guide -## Commands -Use `just` for all development tasks: +## Quick Start -| Command | Description | -|---------|-------------| -| `just install` | Install dependencies | -| `just check` | Run lint, format, and typecheck | -| `just fix` | Auto-fix lint/format issues | -| `just test` | Run pytest test suite | -| `just server` | Start A2A server (port 8000) | -| `just tui` | Run TUI application | -| `just tui-dev` | Run TUI with Textual devtools | -| `just web` | Serve TUI as web app | -| `just console` | Run Textual devtools console | -| `just get-card` | Fetch agent card (CLI) | -| `just send` | Send message to agent (CLI) | -| `just validate` | Validate agent card (CLI) | -| `just version` | Show current version | -| `just bump` | Bump version (patch, minor, major) | -| `just tag` | Create git tag for current version | -| `just release` | Tag and push release to origin | +```bash +just install # Install dependencies +just check # Run lint, format, and typecheck +just test # Run tests +``` + +Run `just` to see all available commands. ## Project Structure ``` -handler/ -├── src/a2a_handler/ # Main package -│ ├── _version.py # Version string -│ ├── cli.py # CLI (rich-click) -│ ├── client.py # A2A protocol client (a2a-sdk) -│ ├── validation.py # Agent card validation utilities -│ ├── server.py # A2A server agent (google-adk, litellm) -│ ├── tui.py # TUI application (textual) -│ ├── common/ # Shared utilities (rich, logging) -│ │ ├── logging.py -│ │ └── printing.py -│ └── components/ # TUI components -└── tests/ # pytest tests +src/a2a_handler/ +├── cli.py # CLI entry point +├── client.py # A2A protocol client +├── validation.py # Agent card validation +├── server.py # A2A server agent +├── tui.py # TUI application +├── common/ # Shared utilities +└── components/ # TUI components ``` -## Code Style & Conventions +## Code Style - **Python 3.11+** with full type hints -- **Formatting**: `ruff format` (black compatible) +- **Formatting**: `ruff format` - **Linting**: `ruff check` - **Type Checking**: `ty check` -- **Imports**: Standard → Third-party → Local -- **Testing**: pytest with pytest-asyncio for async tests - -## Environment Variables - -- `OLLAMA_API_BASE`: Ollama server URL (default: `http://localhost:11434`) -- `OLLAMA_MODEL`: Model to use (default: `qwen3`) - -## A2A Protocol - -The `a2a_handler.client` module provides A2A protocol logic: -- `build_http_client()` - Create configured HTTP client -- `fetch_agent_card()` - Retrieve agent metadata -- `send_message_to_agent()` - Send messages and get responses - -## Key Dependencies - -- **CLI**: `rich-click` (enhanced `click` with rich formatting) -- **Client**: `a2a-sdk`, `httpx` -- **Server**: `google-adk`, `litellm`, `uvicorn` -- **TUI**: `textual` -- **Common**: `rich` +- **Testing**: pytest with pytest-asyncio diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d9c1d34..7555103 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,78 +1,27 @@ # Contributing to Handler -## Architecture - -Handler is a single Python package (`a2a-handler`) with all modules under `src/a2a_handler/`: - -| Module | Description | -|--------|-------------| -| `cli.py` | CLI built with `rich-click`. Entry point: `handler` | -| `client.py` | A2A protocol client library using `a2a-sdk` | -| `validation.py` | Agent card validation utilities | -| `common/` | Shared utilities (logging, printing with `rich`) | -| `server.py` | Reference A2A agent using `google-adk` + `litellm` | -| `tui.py` | TUI application built with `textual` | -| `components/` | TUI components | - ## Prerequisites - **Python 3.11+** - **[uv](https://github.com/astral-sh/uv)** for dependency management -- **[just](https://github.com/casey/just)** for running commands (recommended) -- **[Ollama](https://ollama.com/)** for running the reference server agent +- **[just](https://github.com/casey/just)** for running commands +- **[Ollama](https://ollama.com/)** for running the server agent ## Setup ```bash git clone https://github.com/alDuncanson/handler.git cd handler -just install # or: uv sync +just install ``` -## Development Commands +## Development -| Command | Description | -|---------|-------------| -| `just install` | Install dependencies | -| `just check` | Run lint, format, and typecheck | -| `just fix` | Auto-fix lint/format issues | -| `just test` | Run pytest test suite | -| `just server` | Start A2A server on port 8000 | -| `just tui` | Run TUI application | -| `just tui-dev` | Run TUI with Textual devtools | -| `just web` | Serve TUI as web app | -| `just console` | Run Textual devtools console | -| `just get-card [url]` | Fetch agent card from URL | -| `just send [url] [msg]` | Send message to agent | -| `just validate [source]` | Validate agent card from URL or file | -| `just version` | Show current version | -| `just bump [level]` | Bump version (patch, minor, major) | -| `just tag` | Create git tag for current version | -| `just release` | Tag and push release to origin | +Run `just` to see all available commands. ## Code Style -- **Formatting**: `ruff format` (black compatible) +- **Formatting**: `ruff format` - **Linting**: `ruff check` - **Type Checking**: `ty check` -- **Imports**: Standard → Third-party → Local -- **Testing**: Add `pytest` tests for new functionality - -## Environment Variables - -| Variable | Default | Description | -|----------|---------|-------------| -| `OLLAMA_API_BASE` | `http://localhost:11434` | Ollama server URL | -| `OLLAMA_MODEL` | `qwen3` | Model for reference agent | - -## A2A Protocol - -The `a2a_handler.client` module provides the A2A protocol implementation: - -```python -from a2a_handler.client import build_http_client, fetch_agent_card, send_message_to_agent - -async with build_http_client() as client: - card = await fetch_agent_card("http://localhost:8000", client) - response = await send_message_to_agent("http://localhost:8000", "Hello", client) -``` +- **Testing**: pytest diff --git a/README.md b/README.md index 8530728..e89bec9 100644 --- a/README.md +++ b/README.md @@ -27,74 +27,16 @@ uvx --from a2a-handler handler ## Use -Then, you can use Handler: - ```bash -handler +handler --help ``` -If you don't have an A2A server to connect to, Handler provides a local A2A server agent: +To start a local A2A server agent (requires [Ollama](https://ollama.com/)): ```bash handler server ``` -> The server agent requires [Ollama](https://ollama.com/) to be running locally. By default it connects to `http://localhost:11434` and uses the `qwen3` model. -> -> 1. Install and run Ollama -> 2. Pull the model: `ollama pull qwen3` -> 3. (Optional) Configure via environment variables: `OLLAMA_API_BASE` and `OLLAMA_MODEL` - -### TUI - -Interactive terminal user interface: - -```bash -handler tui -``` - -### CLI - -#### Global Options - -```bash -handler --verbose # Enable verbose logging output -handler --debug # Enable debug logging output -handler --help # Show help for any command -``` - -#### Commands - -Fetch agent card from A2A server: - -```bash -handler card http://localhost:8000 -handler card http://localhost:8000 --output json # JSON output -``` - -Validate an agent card from a URL or file: - -```bash -handler validate http://localhost:8000 # Validate from URL -handler validate ./agent-card.json # Validate from file -handler validate http://localhost:8000 --output json # JSON output -``` - -Send a message to an A2A agent: - -```bash -handler send http://localhost:8000 "Hello World" -handler send http://localhost:8000 "Hello" --output json # JSON output -handler send http://localhost:8000 "Hello" --context-id abc # Conversation continuity -handler send http://localhost:8000 "Hello" --task-id xyz # Reference existing task -``` - -Show version: - -```bash -handler version -``` - ## Contributing -See [CONTRIBUTING.md](CONTRIBUTING.md) for architecture and development instructions. +See [CONTRIBUTING.md](CONTRIBUTING.md). From 05ec2d53cbabe99da7cfda76bd5c305f566f40dc Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Sun, 7 Dec 2025 20:22:34 -0500 Subject: [PATCH 06/23] refactor: cleanup, organize code, and test --- src/a2a_handler/cli.py | 41 ++- src/a2a_handler/client.py | 147 -------- src/a2a_handler/components/agent_card.py | 231 ------------ src/a2a_handler/push_server.py | 166 --------- src/a2a_handler/server.py | 102 +++--- .../{a2a_service.py => service.py} | 339 ++++++++++-------- src/a2a_handler/session.py | 108 +++--- src/a2a_handler/tui.py | 240 ------------- src/a2a_handler/tui/__init__.py | 5 + src/a2a_handler/tui/app.py | 252 +++++++++++++ src/a2a_handler/{tui.tcss => tui/app.tcss} | 0 .../{ => tui}/components/__init__.py | 2 +- src/a2a_handler/tui/components/card.py | 266 ++++++++++++++ .../{ => tui}/components/contact.py | 23 +- .../{ => tui}/components/footer.py | 2 + src/a2a_handler/{ => tui}/components/input.py | 13 +- .../{ => tui}/components/messages.py | 45 +-- src/a2a_handler/validation.py | 209 +++++------ src/a2a_handler/webhook.py | 177 +++++++++ tests/test_service.py | 165 +++++++++ tests/test_session.py | 185 ++++++++++ tests/test_tui.py | 1 + tests/test_validation.py | 12 +- tests/test_webhook.py | 158 ++++++++ 24 files changed, 1706 insertions(+), 1183 deletions(-) delete mode 100644 src/a2a_handler/client.py delete mode 100644 src/a2a_handler/components/agent_card.py delete mode 100644 src/a2a_handler/push_server.py rename src/a2a_handler/{a2a_service.py => service.py} (51%) delete mode 100644 src/a2a_handler/tui.py create mode 100644 src/a2a_handler/tui/__init__.py create mode 100644 src/a2a_handler/tui/app.py rename src/a2a_handler/{tui.tcss => tui/app.tcss} (100%) rename src/a2a_handler/{ => tui}/components/__init__.py (87%) create mode 100644 src/a2a_handler/tui/components/card.py rename src/a2a_handler/{ => tui}/components/contact.py (54%) rename src/a2a_handler/{ => tui}/components/footer.py (90%) rename src/a2a_handler/{ => tui}/components/input.py (74%) rename src/a2a_handler/{ => tui}/components/messages.py (59%) create mode 100644 src/a2a_handler/webhook.py create mode 100644 tests/test_service.py create mode 100644 tests/test_session.py create mode 100644 tests/test_webhook.py diff --git a/src/a2a_handler/cli.py b/src/a2a_handler/cli.py index 3c1d8fd..ded8ee5 100644 --- a/src/a2a_handler/cli.py +++ b/src/a2a_handler/cli.py @@ -103,13 +103,8 @@ A2AClientTimeoutError, ) -from a2a_handler.a2a_service import A2AService, SendResult, TaskResult # noqa: E402 -from a2a_handler.client import ( # noqa: E402 - build_http_client, - fetch_agent_card, -) -from a2a_handler.push_server import run_webhook_server # noqa: E402 from a2a_handler.server import run_server # noqa: E402 +from a2a_handler.service import A2AService, SendResult, TaskResult # noqa: E402 from a2a_handler.session import ( # noqa: E402 clear_session, get_session, @@ -122,6 +117,15 @@ validate_agent_card_from_file, validate_agent_card_from_url, ) +from a2a_handler.webhook import run_webhook_server # noqa: E402 + +TIMEOUT = 120 + + +def build_http_client(timeout: int = TIMEOUT) -> httpx.AsyncClient: + """Build an HTTP client with the specified timeout.""" + return httpx.AsyncClient(timeout=timeout) + log = get_logger(__name__) @@ -324,7 +328,8 @@ async def fetch() -> None: log.debug("Building HTTP client") async with build_http_client() as client: log.debug("Requesting agent card") - card_data = await fetch_agent_card(agent_url, client) + service = A2AService(client, agent_url) + card_data = await service.get_card() log.info("Retrieved card for agent: %s", card_data.name) if output == "json": @@ -377,12 +382,20 @@ def _format_validation_result(result: ValidationResult, output: str) -> None: "agentName": result.agent_name, "protocolVersion": result.protocol_version, "issues": [ - {"field": i.field, "message": i.message, "type": i.issue_type} - for i in result.issues + { + "field": issue.field_name, + "message": issue.message, + "type": issue.issue_type, + } + for issue in result.issues ], "warnings": [ - {"field": w.field, "message": w.message, "type": w.issue_type} - for w in result.warnings + { + "field": warning.field_name, + "message": warning.message, + "type": warning.issue_type, + } + for warning in result.warnings ], } print_json(json.dumps(output_data, indent=2)) @@ -403,7 +416,7 @@ def _format_validation_result(result: ValidationResult, output: str) -> None: ) for warning in result.warnings: content_parts.append( - f" [yellow]⚠[/yellow] {warning.field}: {warning.message}" + f" [yellow]⚠[/yellow] {warning.field_name}: {warning.message}" ) print_panel("\n".join(content_parts), title=title) @@ -416,7 +429,7 @@ def _format_validation_result(result: ValidationResult, output: str) -> None: ] for issue in result.issues: - content_parts.append(f" [red]✗[/red] {issue.field}: {issue.message}") + content_parts.append(f" [red]✗[/red] {issue.field_name}: {issue.message}") print_panel("\n".join(content_parts), title=title) @@ -526,7 +539,7 @@ async def send_msg() -> None: service = A2AService( http_client, agent_url, - streaming=stream, + enable_streaming=stream, push_notification_url=push_url, push_notification_token=push_token, ) diff --git a/src/a2a_handler/client.py b/src/a2a_handler/client.py deleted file mode 100644 index 9786a54..0000000 --- a/src/a2a_handler/client.py +++ /dev/null @@ -1,147 +0,0 @@ -"""A2A protocol client utilities. - -This module provides backwards-compatible functions that wrap A2AService. -For new code, use A2AService directly. -""" - -from dataclasses import dataclass -from typing import Any - -import httpx -from a2a.types import AgentCard - -from a2a_handler.a2a_service import A2AService -from a2a_handler.common import get_logger - -log = get_logger(__name__) - -TIMEOUT = 120 - - -def build_http_client(timeout: int = TIMEOUT) -> httpx.AsyncClient: - """Build an HTTP client with the specified timeout. - - Args: - timeout: Request timeout in seconds - - Returns: - Configured HTTP client - """ - return httpx.AsyncClient(timeout=timeout) - - -async def fetch_agent_card(agent_url: str, client: httpx.AsyncClient) -> AgentCard: - """Fetch agent card from the specified URL. - - Args: - agent_url: The base URL of the agent - client: HTTP client to use for the request - - Returns: - The agent's card with metadata and capabilities - - Raises: - httpx.RequestError: If the request fails - """ - service = A2AService(client, agent_url) - return await service.get_card() - - -async def send_message_to_agent( - agent_url: str, - message_text: str, - client: httpx.AsyncClient, - context_id: str | None = None, - task_id: str | None = None, -) -> dict[str, Any]: - """Send a message to an agent and return the response. - - Args: - agent_url: The base URL of the agent - message_text: The message to send - client: HTTP client to use - context_id: Optional context ID for conversation continuity - task_id: Optional task ID to reference - - Returns: - Response data as a dictionary - - Raises: - httpx.RequestError: If the request fails - httpx.TimeoutException: If the request times out - """ - service = A2AService(client, agent_url) - result = await service.send(message_text, context_id, task_id) - return result.raw - - -@dataclass -class ParsedResponse: - """Parsed A2A response with extracted text content.""" - - text: str - raw: dict[str, Any] - context_id: str | None = None - task_id: str | None = None - - @property - def has_content(self) -> bool: - """Check if the response has meaningful content.""" - return bool(self.text) - - -def parse_response(response: dict[str, Any]) -> ParsedResponse: - """Parse an A2A response and extract text content. - - Args: - response: Raw response dictionary from send_message_to_agent - - Returns: - ParsedResponse with extracted text and raw data - """ - if not response: - log.debug("Empty response received") - return ParsedResponse(text="", raw=response) - - texts: list[str] = [] - context_id = response.get("context_id") - task_id = response.get("id") or response.get("task_id") - - if "parts" in response: - for p in response["parts"]: - if isinstance(p, dict): - if "root" in p and isinstance(p["root"], dict): - texts.append(p["root"].get("text", "")) - else: - texts.append(p.get("text", "")) - log.debug("Extracted %d parts from response", len(response["parts"])) - - for artifact in response.get("artifacts", []): - artifact_parts = artifact.get("parts", []) - for p in artifact_parts: - if isinstance(p, dict): - if "root" in p and isinstance(p["root"], dict): - texts.append(p["root"].get("text", "")) - else: - texts.append(p.get("text", "")) - log.debug("Extracted %d parts from artifact", len(artifact_parts)) - - if "history" in response: - for msg in response["history"]: - if msg.get("role") == "agent": - for p in msg.get("parts", []): - if isinstance(p, dict): - if "root" in p and isinstance(p["root"], dict): - texts.append(p["root"].get("text", "")) - else: - texts.append(p.get("text", "")) - - text = "\n".join(t for t in texts if t) - log.debug("Parsed response with %d characters", len(text)) - - return ParsedResponse( - text=text, - raw=response, - context_id=context_id, - task_id=task_id, - ) diff --git a/src/a2a_handler/components/agent_card.py b/src/a2a_handler/components/agent_card.py deleted file mode 100644 index 1cddb4d..0000000 --- a/src/a2a_handler/components/agent_card.py +++ /dev/null @@ -1,231 +0,0 @@ -import json -import logging -import re -from typing import Any - -from a2a.types import AgentCard -from rich.syntax import Syntax -from textual.app import ComposeResult -from textual.binding import Binding -from textual.containers import Container, VerticalScroll -from textual.widgets import Static, TabbedContent, TabPane, Tabs - -logger = logging.getLogger(__name__) - -TEXTUAL_TO_SYNTAX_THEME: dict[str, str] = { - "gruvbox": "gruvbox-dark", - "nord": "nord", - "tokyo-night": "monokai", - "textual-dark": "monokai", - "textual-light": "default", - "solarized-light": "solarized-light", - "dracula": "dracula", - "catppuccin-mocha": "monokai", - "monokai": "monokai", -} - - -class AgentCardPanel(Container): - """Panel displaying agent card information with tabs.""" - - BINDINGS = [ - Binding("h", "prev_tab", "Prev Tab", show=False), - Binding("l", "next_tab", "Next Tab", show=False), - Binding("left", "prev_tab", "Prev Tab", show=False), - Binding("right", "next_tab", "Next Tab", show=False), - Binding("j", "scroll_down", "Scroll Down", show=False), - Binding("k", "scroll_up", "Scroll Up", show=False), - Binding("down", "scroll_down", "Scroll Down", show=False), - Binding("up", "scroll_up", "Scroll Up", show=False), - ] - - can_focus = True - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._agent_card: AgentCard | None = None - - def compose(self) -> ComposeResult: - with TabbedContent(id="agent-card-tabs"): - with TabPane("Short", id="short-tab"): - yield VerticalScroll( - Static("Not connected", id="agent-short"), - id="short-scroll", - ) - with TabPane("Long", id="long-tab"): - yield VerticalScroll( - Static("", id="agent-long"), - id="long-scroll", - ) - with TabPane("Raw", id="raw-tab"): - yield VerticalScroll( - Static("", id="agent-raw"), - id="raw-scroll", - ) - - def on_mount(self) -> None: - self.border_title = "AGENT CARD" - self.border_subtitle = "READY" - for widget in self.query("TabbedContent, Tabs, Tab, TabPane, VerticalScroll"): - widget.can_focus = False - - def _get_syntax_theme(self) -> str: - """Get the Rich Syntax theme name for the current app theme.""" - return TEXTUAL_TO_SYNTAX_THEME.get(self.app.theme or "", "monokai") - - def _format_key(self, key: str) -> str: - """Convert a key to sentence case.""" - spaced = re.sub(r"([a-z])([A-Z])", r"\1 \2", key) - return spaced.replace("_", " ").capitalize() - - def _is_empty(self, value: Any) -> bool: - """Check if a value is truly empty, including nested structures.""" - if value is None: - return True - if isinstance(value, (str, list, dict)) and not value: - return True - if isinstance(value, dict): - return all(self._is_empty(v) for v in value.values()) - if isinstance(value, list): - return all(self._is_empty(v) for v in value) - return False - - def _format_value(self, value: Any, indent: int = 0) -> str: - """Format a nested value for display.""" - prefix = " " * indent - if isinstance(value, dict): - lines = [] - for k, v in value.items(): - if self._is_empty(v): - continue - formatted_key = self._format_key(k) - if isinstance(v, (list, dict)): - lines.append(f"{prefix}[bold]{formatted_key}[/]") - lines.append(self._format_value(v, indent + 1)) - else: - lines.append(f"{prefix}[bold]{formatted_key}:[/] {v}") - return "\n".join(lines) - if isinstance(value, list): - lines = [] - for item in value: - if self._is_empty(item): - continue - if isinstance(item, dict): - lines.append(self._format_value(item, indent)) - else: - lines.append(f"{prefix}• {item}") - return "\n".join(lines) - return f"{prefix}{value}" - - def _build_short_view(self, card: AgentCard) -> str: - """Build the short view with essential fields only.""" - card_dict = card.model_dump() - lines = [] - - short_fields = [ - "name", - "description", - "version", - "url", - "defaultInputModes", - "defaultOutputModes", - ] - - for key in short_fields: - value = card_dict.get(key) - if self._is_empty(value): - continue - formatted_key = self._format_key(key) - if isinstance(value, (list, dict)): - lines.append(f"[bold]{formatted_key}[/]") - lines.append(self._format_value(value, indent=1)) - else: - lines.append(f"[bold]{formatted_key}:[/] {value}") - - return "\n".join(lines) - - def _build_long_view(self, card: AgentCard) -> str: - """Build the long view with all non-empty fields.""" - card_dict = card.model_dump() - lines = [] - - for key, value in card_dict.items(): - if self._is_empty(value): - continue - formatted_key = self._format_key(key) - if isinstance(value, (list, dict)): - lines.append(f"[bold]{formatted_key}[/]") - lines.append(self._format_value(value, indent=1)) - else: - lines.append(f"[bold]{formatted_key}:[/] {value}") - - return "\n".join(lines) - - def update_card(self, card: AgentCard | None) -> None: - """Update the displayed agent card.""" - self._agent_card = card - - if card is None: - self.query_one("#agent-short", Static).update("Not connected") - self.query_one("#agent-long", Static).update("") - self.query_one("#agent-raw", Static).update("") - self.border_subtitle = "READY" - else: - self.query_one("#agent-short", Static).update(self._build_short_view(card)) - self.query_one("#agent-long", Static).update(self._build_long_view(card)) - - json_str = json.dumps(card.model_dump(), indent=2, default=str) - self.query_one("#agent-raw", Static).update( - Syntax(json_str, "json", theme=self._get_syntax_theme()) - ) - self.border_subtitle = "ACTIVE" - - def refresh_theme(self) -> None: - """Refresh the raw view syntax highlighting for theme changes.""" - if self._agent_card is None: - return - json_str = json.dumps(self._agent_card.model_dump(), indent=2, default=str) - self.query_one("#agent-raw", Static).update( - Syntax(json_str, "json", theme=self._get_syntax_theme()) - ) - - def _get_active_scroll(self) -> VerticalScroll | None: - """Get the currently visible scroll container.""" - tabs = self.query_one("#agent-card-tabs", TabbedContent) - active_tab = tabs.active - - if active_tab == "short-tab": - return self.query_one("#short-scroll", VerticalScroll) - elif active_tab == "long-tab": - return self.query_one("#long-scroll", VerticalScroll) - elif active_tab == "raw-tab": - return self.query_one("#raw-scroll", VerticalScroll) - return None - - def action_prev_tab(self) -> None: - """Switch to the previous tab.""" - try: - tabs = self.query_one("#agent-card-tabs Tabs", Tabs) - tabs.action_previous_tab() - except Exception: - pass - - def action_next_tab(self) -> None: - """Switch to the next tab.""" - try: - tabs = self.query_one("#agent-card-tabs Tabs", Tabs) - tabs.action_next_tab() - except Exception: - pass - - def action_scroll_down(self) -> None: - """Scroll down in the active tab's scroll container.""" - scroll = self._get_active_scroll() - if scroll: - scroll.scroll_down() - - def action_scroll_up(self) -> None: - """Scroll up in the active tab's scroll container.""" - scroll = self._get_active_scroll() - if scroll: - scroll.scroll_up() diff --git a/src/a2a_handler/push_server.py b/src/a2a_handler/push_server.py deleted file mode 100644 index e80fe55..0000000 --- a/src/a2a_handler/push_server.py +++ /dev/null @@ -1,166 +0,0 @@ -"""Local webhook server for receiving A2A push notifications. - -This module provides a simple HTTP server that can receive push notifications -from A2A agents for testing purposes. -""" - -import json -from collections import deque -from dataclasses import dataclass, field -from datetime import datetime -from typing import Any - -import uvicorn -from starlette.applications import Starlette -from starlette.requests import Request -from starlette.responses import JSONResponse -from starlette.routing import Route - -from a2a_handler.common import console, get_logger - -log = get_logger(__name__) - -MAX_NOTIFICATIONS = 100 - - -@dataclass -class Notification: - """A received push notification.""" - - timestamp: datetime - task_id: str | None - payload: dict[str, Any] - headers: dict[str, str] - - -@dataclass -class NotificationStore: - """In-memory store for received notifications.""" - - notifications: deque[Notification] = field( - default_factory=lambda: deque(maxlen=MAX_NOTIFICATIONS) - ) - - def add(self, notification: Notification) -> None: - """Add a notification to the store.""" - self.notifications.append(notification) - - def get_all(self) -> list[Notification]: - """Get all stored notifications.""" - return list(self.notifications) - - def clear(self) -> None: - """Clear all stored notifications.""" - self.notifications.clear() - - -notification_store = NotificationStore() - - -async def handle_notification(request: Request) -> JSONResponse: - """Handle incoming push notifications from A2A agents.""" - try: - payload = await request.json() - except json.JSONDecodeError: - log.warning("Received invalid JSON in push notification") - return JSONResponse({"error": "Invalid JSON"}, status_code=400) - - headers = dict(request.headers) - task_id = payload.get("id") or payload.get("task_id") - - notification = Notification( - timestamp=datetime.now(), - task_id=task_id, - payload=payload, - headers=headers, - ) - notification_store.add(notification) - - log.info("Received push notification for task: %s", task_id) - - console.print("\n[bold cyan]Push Notification Received[/bold cyan]") - console.print(f"[dim]Timestamp:[/dim] {notification.timestamp.isoformat()}") - if task_id: - console.print(f"[dim]Task ID:[/dim] {task_id}") - - status = payload.get("status", {}) - if status: - state = status.get("state", "unknown") - console.print(f"[dim]State:[/dim] {state}") - - token = headers.get("x-a2a-notification-token") - if token: - console.print(f"[dim]Token:[/dim] {token[:20]}...") - - console.print() - console.print_json(json.dumps(payload, indent=2, default=str)) - console.print() - - return JSONResponse({"status": "ok", "received": True}) - - -async def handle_validation(request: Request) -> JSONResponse: - """Handle GET requests for webhook validation.""" - log.info("Webhook validation request received") - return JSONResponse({"status": "ok", "message": "Webhook is active"}) - - -async def handle_list(request: Request) -> JSONResponse: - """List all received notifications.""" - notifications = notification_store.get_all() - return JSONResponse( - { - "count": len(notifications), - "notifications": [ - { - "timestamp": n.timestamp.isoformat(), - "task_id": n.task_id, - "payload": n.payload, - } - for n in notifications - ], - } - ) - - -async def handle_clear(request: Request) -> JSONResponse: - """Clear all stored notifications.""" - notification_store.clear() - log.info("Cleared all stored notifications") - return JSONResponse({"status": "ok", "message": "Notifications cleared"}) - - -def create_webhook_app() -> Starlette: - """Create the webhook Starlette application.""" - routes = [ - Route("/webhook", handle_notification, methods=["POST"]), - Route("/webhook", handle_validation, methods=["GET"]), - Route("/notifications", handle_list, methods=["GET"]), - Route("/notifications/clear", handle_clear, methods=["POST"]), - ] - return Starlette(routes=routes) - - -def run_webhook_server(host: str = "127.0.0.1", port: int = 9000) -> None: - """Start the webhook server. - - Args: - host: Host address to bind to - port: Port number to bind to - """ - console.print(f"\n[bold]Starting webhook server on [url]{host}:{port}[/url][/bold]") - console.print() - console.print("[dim]Endpoints:[/dim]") - console.print(f" POST http://{host}:{port}/webhook - Receive notifications") - console.print(f" GET http://{host}:{port}/webhook - Validation check") - console.print(f" GET http://{host}:{port}/notifications - List received") - console.print(f" POST http://{host}:{port}/notifications/clear - Clear stored") - console.print() - console.print( - f"[bold green]Use this URL for push notifications:[/bold green] " - f"http://{host}:{port}/webhook" - ) - console.print() - - app = create_webhook_app() - uvicorn.run(app, host=host, port=port, log_level="warning") diff --git a/src/a2a_handler/server.py b/src/a2a_handler/server.py index b8a7177..f404576 100644 --- a/src/a2a_handler/server.py +++ b/src/a2a_handler/server.py @@ -9,9 +9,9 @@ from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import ( + BasePushNotificationSender, InMemoryPushNotificationConfigStore, InMemoryTaskStore, - BasePushNotificationSender, ) from a2a.types import AgentCapabilities, AgentCard, AgentSkill from dotenv import load_dotenv @@ -30,10 +30,14 @@ from a2a_handler.common import console, get_logger, setup_logging setup_logging(level="INFO", suppress_libs=["uvicorn", "google"]) -log = get_logger(__name__) +logger = get_logger(__name__) + +DEFAULT_OLLAMA_API_BASE = "http://localhost:11434" +DEFAULT_OLLAMA_MODEL = "qwen3" +DEFAULT_HTTP_TIMEOUT_SECONDS = 30 -def create_agent() -> Agent: +def create_llm_agent() -> Agent: """Create and configure the A2A test agent using LiteLLM with Ollama. Returns: @@ -41,24 +45,24 @@ def create_agent() -> Agent: """ load_dotenv() - ollama_base = os.getenv("OLLAMA_API_BASE", "http://localhost:11434") - ollama_model = os.getenv("OLLAMA_MODEL", "qwen3") + ollama_api_base = os.getenv("OLLAMA_API_BASE", DEFAULT_OLLAMA_API_BASE) + ollama_model_name = os.getenv("OLLAMA_MODEL", DEFAULT_OLLAMA_MODEL) - log.info( + logger.info( "Creating agent with model: [highlight]%s[/highlight] at [url]%s[/url]", - ollama_model, - ollama_base, + ollama_model_name, + ollama_api_base, ) - model = LiteLlm( - model=f"ollama_chat/{ollama_model}", - api_base=ollama_base, + language_model = LiteLlm( + model=f"ollama_chat/{ollama_model_name}", + api_base=ollama_api_base, reasoning_effort="none", ) agent = Agent( name="Handler", - model=model, + model=language_model, description="Handler assistant", instruction="""You are Handler, the resident helpful agent for the Handler application. You are an expert on the Handler toolkit, which is a terminal-based system for communicating with and testing Agent-to-Agent (A2A) protocol agents. @@ -73,7 +77,7 @@ def create_agent() -> Agent: You are proud to be an A2A server agent.""", ) - log.info( + logger.info( "[success]Agent created successfully:[/success] [agent]%s[/agent]", agent.name ) return agent @@ -90,12 +94,12 @@ def build_agent_card(agent: Agent, host: str, port: int) -> AgentCard: Returns: Configured AgentCard with capabilities enabled """ - capabilities = AgentCapabilities( + agent_capabilities = AgentCapabilities( streaming=True, push_notifications=True, ) - skill = AgentSkill( + agent_skill = AgentSkill( id="handler_assistant", name="Handler Assistant", description="Answers questions about the Handler A2A toolkit and helps with usage", @@ -103,15 +107,17 @@ def build_agent_card(agent: Agent, host: str, port: int) -> AgentCard: examples=["What is Handler?", "How do I use the CLI?", "Tell me about A2A"], ) - rpc_url = f"http://{host}:{port}/" + rpc_endpoint_url = f"http://{host}:{port}/" + + logger.debug("Building agent card with RPC URL: %s", rpc_endpoint_url) return AgentCard( name=agent.name, description=agent.description or "Handler A2A agent", - url=rpc_url, + url=rpc_endpoint_url, version="1.0.0", - capabilities=capabilities, - skills=[skill], + capabilities=agent_capabilities, + skills=[agent_skill], default_input_modes=["text/plain"], default_output_modes=["text/plain"], ) @@ -140,7 +146,7 @@ async def create_runner() -> Runner: return create_runner -def create_a2a_app(agent: Agent, agent_card: AgentCard) -> Starlette: +def create_a2a_application(agent: Agent, agent_card: AgentCard) -> Starlette: """Create a Starlette A2A application with full push notification support. This is a custom implementation that replaces google-adk's to_a2a() to add @@ -156,9 +162,11 @@ def create_a2a_app(agent: Agent, agent_card: AgentCard) -> Starlette: Configured Starlette application """ task_store = InMemoryTaskStore() - push_config_store = InMemoryPushNotificationConfigStore() - http_client = httpx.AsyncClient(timeout=30.0) - push_sender = BasePushNotificationSender(http_client, push_config_store) + push_notification_config_store = InMemoryPushNotificationConfigStore() + http_client = httpx.AsyncClient(timeout=DEFAULT_HTTP_TIMEOUT_SECONDS) + push_notification_sender = BasePushNotificationSender( + http_client, push_notification_config_store + ) agent_executor = A2aAgentExecutor( runner=create_runner_factory(agent), @@ -167,28 +175,28 @@ def create_a2a_app(agent: Agent, agent_card: AgentCard) -> Starlette: request_handler = DefaultRequestHandler( agent_executor=agent_executor, task_store=task_store, - push_config_store=push_config_store, - push_sender=push_sender, + push_config_store=push_notification_config_store, + push_sender=push_notification_sender, ) - app = Starlette() + application = Starlette() - async def setup_a2a() -> None: - a2a_app = A2AStarletteApplication( + async def setup_a2a_routes() -> None: + a2a_starlette_app = A2AStarletteApplication( agent_card=agent_card, http_handler=request_handler, ) - a2a_app.add_routes_to_app(app) - log.info("A2A routes configured with push notification support") + a2a_starlette_app.add_routes_to_app(application) + logger.info("A2A routes configured with push notification support") - async def cleanup() -> None: + async def cleanup_http_client() -> None: await http_client.aclose() - log.info("HTTP client closed") + logger.info("HTTP client closed") - app.add_event_handler("startup", setup_a2a) - app.add_event_handler("shutdown", cleanup) + application.add_event_handler("startup", setup_a2a_routes) + application.add_event_handler("shutdown", cleanup_http_client) - return app + return application def run_server(host: str, port: int) -> None: @@ -201,20 +209,26 @@ def run_server(host: str, port: int) -> None: console.print( f"\n[bold]Starting Handler server on [url]{host}:{port}[/url][/bold]\n" ) - log.info("Initializing A2A server with push notification support...") - agent = create_agent() + logger.info("Initializing A2A server with push notification support...") + agent = create_llm_agent() agent_card = build_agent_card(agent, host, port) - log.info( + + streaming_enabled = ( + agent_card.capabilities.streaming if agent_card.capabilities else False + ) + push_notifications_enabled = ( + agent_card.capabilities.push_notifications if agent_card.capabilities else False + ) + + logger.info( "Agent card capabilities: streaming=%s, push_notifications=%s", - agent_card.capabilities.streaming if agent_card.capabilities else False, - agent_card.capabilities.push_notifications - if agent_card.capabilities - else False, + streaming_enabled, + push_notifications_enabled, ) - a2a_app = create_a2a_app(agent, agent_card) - uvicorn.run(a2a_app, host=host, port=port) + a2a_application = create_a2a_application(agent, agent_card) + uvicorn.run(a2a_application, host=host, port=port) @click.command() diff --git a/src/a2a_handler/a2a_service.py b/src/a2a_handler/service.py similarity index 51% rename from src/a2a_handler/a2a_service.py rename to src/a2a_handler/service.py index b85e885..1bae5d7 100644 --- a/src/a2a_handler/a2a_service.py +++ b/src/a2a_handler/service.py @@ -29,9 +29,9 @@ from a2a_handler.common import get_logger -log = get_logger(__name__) +logger = get_logger(__name__) -TERMINAL_STATES = { +TERMINAL_TASK_STATES = { TaskState.completed, TaskState.canceled, TaskState.failed, @@ -54,7 +54,7 @@ class SendResult: @property def is_complete(self) -> bool: """Check if the task reached a terminal state.""" - return self.state in TERMINAL_STATES if self.state else False + return self.state in TERMINAL_TASK_STATES if self.state else False @property def needs_input(self) -> bool: @@ -66,7 +66,7 @@ def needs_input(self) -> bool: class StreamEvent: """A single event from a streaming response.""" - event_type: str # "task", "message", "status", "artifact" + event_type: str task: Task | None = None message: Message | None = None status: TaskStatusUpdateEvent | None = None @@ -89,31 +89,36 @@ class TaskResult: raw: dict[str, Any] = field(default_factory=dict) -def _extract_text_from_parts(parts: list[Part] | None) -> str: +def extract_text_from_message_parts(message_parts: list[Part] | None) -> str: """Extract text content from message parts.""" - if not parts: + if not message_parts: return "" - texts = [] - for part in parts: + + extracted_texts = [] + for part in message_parts: if hasattr(part, "root") and hasattr(part.root, "text"): - texts.append(part.root.text) + extracted_texts.append(part.root.text) elif hasattr(part, "text"): - texts.append(part.text) - return "\n".join(t for t in texts if t) + extracted_texts.append(part.text) + + return "\n".join(text for text in extracted_texts if text) -def _extract_text_from_task(task: Task) -> str: +def extract_text_from_task(task: Task) -> str: """Extract text from task artifacts and history.""" - texts = [] + extracted_texts = [] + if task.artifacts: for artifact in task.artifacts: if artifact.parts: - texts.append(_extract_text_from_parts(artifact.parts)) + extracted_texts.append(extract_text_from_message_parts(artifact.parts)) + if task.history: - for msg in task.history: - if msg.role == Role.agent and msg.parts: - texts.append(_extract_text_from_parts(msg.parts)) - return "\n".join(t for t in texts if t) + for message in task.history: + if message.role == Role.agent and message.parts: + extracted_texts.append(extract_text_from_message_parts(message.parts)) + + return "\n".join(text for text in extracted_texts if text) class A2AService: @@ -127,26 +132,26 @@ def __init__( self, http_client: httpx.AsyncClient, agent_url: str, - streaming: bool = True, + enable_streaming: bool = True, push_notification_url: str | None = None, push_notification_token: str | None = None, - ): + ) -> None: """Initialize the A2A service. Args: http_client: Async HTTP client to use for requests agent_url: Base URL of the A2A agent - streaming: Whether to prefer streaming when available + enable_streaming: Whether to prefer streaming when available push_notification_url: Optional webhook URL for push notifications push_notification_token: Optional token for push notification auth """ self.http_client = http_client self.agent_url = agent_url - self.streaming = streaming + self.enable_streaming = enable_streaming self.push_notification_url = push_notification_url self.push_notification_token = push_notification_token - self._client: Client | None = None - self._card: AgentCard | None = None + self._cached_client: Client | None = None + self._cached_agent_card: AgentCard | None = None async def get_card(self) -> AgentCard: """Fetch and cache the agent card. @@ -154,65 +159,71 @@ async def get_card(self) -> AgentCard: Returns: The agent's card with metadata and capabilities """ - if self._card is None: - log.info("Fetching agent card from %s", self.agent_url) - resolver = A2ACardResolver(self.http_client, self.agent_url) - self._card = await resolver.get_agent_card() - log.info("Connected to agent: %s", self._card.name) - return self._card - - async def _get_client(self) -> Client: + if self._cached_agent_card is None: + logger.info("Fetching agent card from %s", self.agent_url) + card_resolver = A2ACardResolver(self.http_client, self.agent_url) + self._cached_agent_card = await card_resolver.get_agent_card() + logger.info("Connected to agent: %s", self._cached_agent_card.name) + return self._cached_agent_card + + async def _get_or_create_client(self) -> Client: """Get or create the A2A client. Returns: Configured A2A client instance """ - if self._client is None: - card = await self.get_card() - push_configs: list[PushNotificationConfig] = [] + if self._cached_client is None: + agent_card = await self.get_card() + + push_notification_configs: list[PushNotificationConfig] = [] if self.push_notification_url: - push_configs.append( + push_notification_configs.append( PushNotificationConfig( url=self.push_notification_url, token=self.push_notification_token, ) ) - log.info("Push notification config: %s", self.push_notification_url) - config = ClientConfig( + logger.info( + "Push notification configured: %s", self.push_notification_url + ) + + client_config = ClientConfig( httpx_client=self.http_client, supported_transports=[TransportProtocol.jsonrpc], - streaming=self.streaming, - push_notification_configs=push_configs, + streaming=self.enable_streaming, + push_notification_configs=push_notification_configs, ) - factory = ClientFactory(config) - self._client = factory.create(card) - log.debug("Created A2A client for %s", card.name) - return self._client + + client_factory = ClientFactory(client_config) + self._cached_client = client_factory.create(agent_card) + logger.debug("Created A2A client for %s", agent_card.name) + + return self._cached_client @property def supports_streaming(self) -> bool: """Check if the agent supports streaming.""" - if self._card and self._card.capabilities: - return bool(self._card.capabilities.streaming) + if self._cached_agent_card and self._cached_agent_card.capabilities: + return bool(self._cached_agent_card.capabilities.streaming) return False @property def supports_push_notifications(self) -> bool: """Check if the agent supports push notifications.""" - if self._card and self._card.capabilities: - return bool(self._card.capabilities.push_notifications) + if self._cached_agent_card and self._cached_agent_card.capabilities: + return bool(self._cached_agent_card.capabilities.push_notifications) return False - def _build_message( + def _build_user_message( self, - text: str, + message_text: str, context_id: str | None = None, task_id: str | None = None, ) -> Message: """Build a user message. Args: - text: Message content + message_text: Message content context_id: Optional context ID for conversation continuity task_id: Optional task ID to continue @@ -222,14 +233,14 @@ def _build_message( return Message( message_id=str(uuid.uuid4()), role=Role.user, - parts=[Part(root=TextPart(text=text))], + parts=[Part(root=TextPart(text=message_text))], context_id=context_id, task_id=task_id, ) async def send( self, - text: str, + message_text: str, context_id: str | None = None, task_id: str | None = None, ) -> SendResult: @@ -238,45 +249,50 @@ async def send( This method collects all streaming events and returns the final result. Args: - text: Message to send + message_text: Message to send context_id: Optional context ID for conversation continuity task_id: Optional task ID to continue Returns: SendResult with task state, extracted text, and IDs """ - client = await self._get_client() - message = self._build_message(text, context_id, task_id) + client = await self._get_or_create_client() + user_message = self._build_user_message(message_text, context_id, task_id) - log.info("Sending message: %s", text[:50] if len(text) > 50 else text) + truncated_message = ( + message_text[:50] if len(message_text) > 50 else message_text + ) + logger.info("Sending message: %s", truncated_message) result = SendResult() - last_task: Task | None = None + last_received_task: Task | None = None - async for event in client.send_message(message): + async for event in client.send_message(user_message): if isinstance(event, Message): result.message = event result.context_id = event.context_id result.task_id = event.task_id - result.text = _extract_text_from_parts(event.parts) - log.debug("Received message response") + result.text = extract_text_from_message_parts(event.parts) + logger.debug("Received message response") elif isinstance(event, tuple): - task, update = event - last_task = task - result.task = task - result.task_id = task.id - result.context_id = task.context_id - if task.status: - result.state = task.status.state - log.debug( + received_task, task_update = event + last_received_task = received_task + result.task = received_task + result.task_id = received_task.id + result.context_id = received_task.context_id + if received_task.status: + result.state = received_task.status.state + logger.debug( "Received task update: %s", - task.status.state if task.status else "unknown", + received_task.status.state if received_task.status else "unknown", ) - if last_task: - result.text = _extract_text_from_task(last_task) + if last_received_task: + result.text = extract_text_from_task(last_received_task) result.raw = ( - last_task.model_dump() if hasattr(last_task, "model_dump") else {} + last_received_task.model_dump() + if hasattr(last_received_task, "model_dump") + else {} ) elif result.message: result.raw = ( @@ -285,75 +301,84 @@ async def send( else {} ) - log.info("Send complete: task_id=%s, state=%s", result.task_id, result.state) + logger.info("Send complete: task_id=%s, state=%s", result.task_id, result.state) return result async def stream( self, - text: str, + message_text: str, context_id: str | None = None, task_id: str | None = None, ) -> AsyncIterator[StreamEvent]: """Send a message and stream responses as they arrive. Args: - text: Message to send + message_text: Message to send context_id: Optional context ID for conversation continuity task_id: Optional task ID to continue Yields: StreamEvent objects as they are received """ - client = await self._get_client() - message = self._build_message(text, context_id, task_id) + client = await self._get_or_create_client() + user_message = self._build_user_message(message_text, context_id, task_id) - log.info("Streaming message: %s", text[:50] if len(text) > 50 else text) + truncated_message = ( + message_text[:50] if len(message_text) > 50 else message_text + ) + logger.info("Streaming message: %s", truncated_message) - async for event in client.send_message(message): + async for event in client.send_message(user_message): if isinstance(event, Message): yield StreamEvent( event_type="message", message=event, context_id=event.context_id, task_id=event.task_id, - text=_extract_text_from_parts(event.parts), + text=extract_text_from_message_parts(event.parts), ) elif isinstance(event, tuple): - task, update = event - if isinstance(update, TaskStatusUpdateEvent): - status_text = "" - if update.status and update.status.message: - status_text = str(update.status.message) + received_task, task_update = event + if isinstance(task_update, TaskStatusUpdateEvent): + status_message_text = "" + if task_update.status and task_update.status.message: + status_message_text = str(task_update.status.message) yield StreamEvent( event_type="status", - task=task, - status=update, - context_id=task.context_id, - task_id=task.id, - state=update.status.state if update.status else None, - text=status_text, + task=received_task, + status=task_update, + context_id=received_task.context_id, + task_id=received_task.id, + state=task_update.status.state if task_update.status else None, + text=status_message_text, ) - elif isinstance(update, TaskArtifactUpdateEvent): + elif isinstance(task_update, TaskArtifactUpdateEvent): artifact_text = "" - if update.artifact and update.artifact.parts: - artifact_text = _extract_text_from_parts(update.artifact.parts) + if task_update.artifact and task_update.artifact.parts: + artifact_text = extract_text_from_message_parts( + task_update.artifact.parts + ) yield StreamEvent( event_type="artifact", - task=task, - artifact=update, - context_id=task.context_id, - task_id=task.id, - state=task.status.state if task.status else None, + task=received_task, + artifact=task_update, + context_id=received_task.context_id, + task_id=received_task.id, + state=( + received_task.status.state if received_task.status else None + ), text=artifact_text, ) else: yield StreamEvent( event_type="task", - task=task, - context_id=task.context_id, - task_id=task.id, - state=task.status.state if task.status else None, - text=_extract_text_from_task(task), + task=received_task, + context_id=received_task.context_id, + task_id=received_task.id, + state=( + received_task.status.state if received_task.status else None + ), + text=extract_text_from_task(received_task), ) async def get_task( @@ -370,19 +395,19 @@ async def get_task( Returns: TaskResult with task state and details """ - client = await self._get_client() + client = await self._get_or_create_client() - params = TaskQueryParams(id=task_id, history_length=history_length) - log.info("Getting task: %s", task_id) + query_params = TaskQueryParams(id=task_id, history_length=history_length) + logger.info("Getting task: %s", task_id) - task = await client.get_task(params) + task = await client.get_task(query_params) return TaskResult( task=task, task_id=task.id, state=task.status.state if task.status else TaskState.unknown, context_id=task.context_id, - text=_extract_text_from_task(task), + text=extract_text_from_task(task), raw=task.model_dump() if hasattr(task, "model_dump") else {}, ) @@ -395,19 +420,19 @@ async def cancel_task(self, task_id: str) -> TaskResult: Returns: TaskResult with updated task state """ - client = await self._get_client() + client = await self._get_or_create_client() - params = TaskIdParams(id=task_id) - log.info("Canceling task: %s", task_id) + task_id_params = TaskIdParams(id=task_id) + logger.info("Canceling task: %s", task_id) - task = await client.cancel_task(params) + task = await client.cancel_task(task_id_params) return TaskResult( task=task, task_id=task.id, state=task.status.state if task.status else TaskState.unknown, context_id=task.context_id, - text=_extract_text_from_task(task), + text=extract_text_from_task(task), raw=task.model_dump() if hasattr(task, "model_dump") else {}, ) @@ -420,75 +445,79 @@ async def resubscribe(self, task_id: str) -> AsyncIterator[StreamEvent]: Yields: StreamEvent objects as they are received """ - client = await self._get_client() + client = await self._get_or_create_client() - params = TaskIdParams(id=task_id) - log.info("Resubscribing to task: %s", task_id) + task_id_params = TaskIdParams(id=task_id) + logger.info("Resubscribing to task: %s", task_id) - async for event in client.resubscribe(params): - task, update = event - if isinstance(update, TaskStatusUpdateEvent): + async for event in client.resubscribe(task_id_params): + received_task, task_update = event + if isinstance(task_update, TaskStatusUpdateEvent): yield StreamEvent( event_type="status", - task=task, - status=update, - context_id=task.context_id, - task_id=task.id, - state=update.status.state if update.status else None, + task=received_task, + status=task_update, + context_id=received_task.context_id, + task_id=received_task.id, + state=task_update.status.state if task_update.status else None, ) - elif isinstance(update, TaskArtifactUpdateEvent): + elif isinstance(task_update, TaskArtifactUpdateEvent): artifact_text = "" - if update.artifact and update.artifact.parts: - artifact_text = _extract_text_from_parts(update.artifact.parts) + if task_update.artifact and task_update.artifact.parts: + artifact_text = extract_text_from_message_parts( + task_update.artifact.parts + ) yield StreamEvent( event_type="artifact", - task=task, - artifact=update, - context_id=task.context_id, - task_id=task.id, - state=task.status.state if task.status else None, + task=received_task, + artifact=task_update, + context_id=received_task.context_id, + task_id=received_task.id, + state=( + received_task.status.state if received_task.status else None + ), text=artifact_text, ) else: yield StreamEvent( event_type="task", - task=task, - context_id=task.context_id, - task_id=task.id, - state=task.status.state if task.status else None, - text=_extract_text_from_task(task), + task=received_task, + context_id=received_task.context_id, + task_id=received_task.id, + state=( + received_task.status.state if received_task.status else None + ), + text=extract_text_from_task(received_task), ) async def set_push_config( self, task_id: str, - url: str, - token: str | None = None, + webhook_url: str, + authentication_token: str | None = None, ) -> TaskPushNotificationConfig: """Set push notification configuration for a task. Args: task_id: ID of the task - url: Webhook URL to receive notifications - token: Optional authentication token + webhook_url: Webhook URL to receive notifications + authentication_token: Optional authentication token Returns: The created push notification configuration """ - client = await self._get_client() - - from a2a.types import PushNotificationConfig + client = await self._get_or_create_client() - config = TaskPushNotificationConfig( + push_config = TaskPushNotificationConfig( task_id=task_id, push_notification_config=PushNotificationConfig( - url=url, - token=token, + url=webhook_url, + token=authentication_token, ), ) - log.info("Setting push config for task %s: %s", task_id, url) + logger.info("Setting push config for task %s: %s", task_id, webhook_url) - return await client.set_task_callback(config) + return await client.set_task_callback(push_config) async def get_push_config( self, @@ -504,12 +533,12 @@ async def get_push_config( Returns: The push notification configuration """ - client = await self._get_client() + client = await self._get_or_create_client() - params = GetTaskPushNotificationConfigParams( + query_params = GetTaskPushNotificationConfigParams( id=task_id, push_notification_config_id=config_id, ) - log.info("Getting push config %s for task %s", config_id, task_id) + logger.info("Getting push config %s for task %s", config_id, task_id) - return await client.get_task_callback(params) + return await client.get_task_callback(query_params) diff --git a/src/a2a_handler/session.py b/src/a2a_handler/session.py index ce89967..a8128a9 100644 --- a/src/a2a_handler/session.py +++ b/src/a2a_handler/session.py @@ -10,10 +10,10 @@ from a2a_handler.common import get_logger -log = get_logger(__name__) +logger = get_logger(__name__) -DEFAULT_SESSION_DIR = Path.home() / ".handler" -SESSION_FILE = "sessions.json" +DEFAULT_SESSION_DIRECTORY = Path.home() / ".handler" +SESSION_FILENAME = "sessions.json" @dataclass @@ -41,64 +41,72 @@ class SessionStore: """Persistent store for agent sessions.""" sessions: dict[str, AgentSession] = field(default_factory=dict) - session_dir: Path = field(default_factory=lambda: DEFAULT_SESSION_DIR) + session_directory: Path = field(default_factory=lambda: DEFAULT_SESSION_DIRECTORY) @property - def session_file(self) -> Path: + def session_file_path(self) -> Path: """Path to the session file.""" - return self.session_dir / SESSION_FILE + return self.session_directory / SESSION_FILENAME - def _ensure_dir(self) -> None: + def _ensure_directory_exists(self) -> None: """Ensure the session directory exists.""" - self.session_dir.mkdir(parents=True, exist_ok=True) + self.session_directory.mkdir(parents=True, exist_ok=True) def load(self) -> None: """Load sessions from disk.""" - if not self.session_file.exists(): - log.debug("No session file found at %s", self.session_file) + if not self.session_file_path.exists(): + logger.debug("No session file found at %s", self.session_file_path) return try: - with open(self.session_file) as f: - data = json.load(f) - - for url, session_data in data.items(): - self.sessions[url] = AgentSession( - agent_url=url, - context_id=session_data.get("context_id"), - task_id=session_data.get("task_id"), + with open(self.session_file_path) as session_file: + session_data = json.load(session_file) + + for agent_url, agent_session_data in session_data.items(): + self.sessions[agent_url] = AgentSession( + agent_url=agent_url, + context_id=agent_session_data.get("context_id"), + task_id=agent_session_data.get("task_id"), ) - log.debug( - "Loaded %d sessions from %s", len(self.sessions), self.session_file + + logger.debug( + "Loaded %d sessions from %s", + len(self.sessions), + self.session_file_path, ) - except json.JSONDecodeError as e: - log.warning("Failed to parse session file: %s", e) - except OSError as e: - log.warning("Failed to read session file: %s", e) + except json.JSONDecodeError as error: + logger.warning("Failed to parse session file: %s", error) + except OSError as error: + logger.warning("Failed to read session file: %s", error) def save(self) -> None: """Save sessions to disk.""" - self._ensure_dir() + self._ensure_directory_exists() - data: dict[str, Any] = {} - for url, session in self.sessions.items(): - data[url] = { - "context_id": session.context_id, - "task_id": session.task_id, + session_data: dict[str, Any] = {} + for agent_url, agent_session in self.sessions.items(): + session_data[agent_url] = { + "context_id": agent_session.context_id, + "task_id": agent_session.task_id, } try: - with open(self.session_file, "w") as f: - json.dump(data, f, indent=2) - log.debug("Saved %d sessions to %s", len(self.sessions), self.session_file) - except OSError as e: - log.warning("Failed to write session file: %s", e) + with open(self.session_file_path, "w") as session_file: + json.dump(session_data, session_file, indent=2) + logger.debug( + "Saved %d sessions to %s", + len(self.sessions), + self.session_file_path, + ) + except OSError as error: + logger.warning("Failed to write session file: %s", error) def get(self, agent_url: str) -> AgentSession: """Get or create a session for an agent URL.""" if agent_url not in self.sessions: self.sessions[agent_url] = AgentSession(agent_url=agent_url) + logger.debug("Created new session for %s", agent_url) return self.sessions[agent_url] def update( @@ -108,10 +116,16 @@ def update( task_id: str | None = None, ) -> AgentSession: """Update session for an agent and save.""" - session = self.get(agent_url) - session.update(context_id, task_id) + agent_session = self.get(agent_url) + agent_session.update(context_id, task_id) self.save() - return session + logger.debug( + "Updated session for %s: context_id=%s, task_id=%s", + agent_url, + context_id, + task_id, + ) + return agent_session def clear(self, agent_url: str | None = None) -> None: """Clear session(s). @@ -123,10 +137,11 @@ def clear(self, agent_url: str | None = None) -> None: if agent_url: if agent_url in self.sessions: del self.sessions[agent_url] - log.info("Cleared session for %s", agent_url) + logger.info("Cleared session for %s", agent_url) else: + session_count = len(self.sessions) self.sessions.clear() - log.info("Cleared all sessions") + logger.info("Cleared all %d sessions", session_count) self.save() def list_all(self) -> list[AgentSession]: @@ -134,16 +149,17 @@ def list_all(self) -> list[AgentSession]: return list(self.sessions.values()) -_store: SessionStore | None = None +_global_session_store: SessionStore | None = None def get_session_store() -> SessionStore: """Get the global session store (singleton).""" - global _store - if _store is None: - _store = SessionStore() - _store.load() - return _store + global _global_session_store + if _global_session_store is None: + _global_session_store = SessionStore() + _global_session_store.load() + logger.debug("Initialized global session store") + return _global_session_store def get_session(agent_url: str) -> AgentSession: diff --git a/src/a2a_handler/tui.py b/src/a2a_handler/tui.py deleted file mode 100644 index 0b34b5e..0000000 --- a/src/a2a_handler/tui.py +++ /dev/null @@ -1,240 +0,0 @@ -import logging -import uuid -from typing import Any - -import httpx -from a2a.types import AgentCard -from importlib.metadata import version - -__version__ = version("a2a-handler") -from a2a_handler.client import ( - build_http_client, - fetch_agent_card, - send_message_to_agent, -) -from textual import on -from textual.app import App, ComposeResult -from textual.binding import Binding -from textual.containers import Container, Vertical -from textual.logging import TextualHandler -from textual.widgets import Button, Input - -from a2a_handler.components import ( - AgentCardPanel, - ContactPanel, - Footer, - InputPanel, - MessagesPanel, -) - -logging.basicConfig( - level="NOTSET", - handlers=[TextualHandler()], -) -logger = logging.getLogger(__name__) - - -class HandlerTUI(App[Any]): - """Handler - A2A Agent Management Interface.""" - - CSS_PATH = "tui.tcss" - - BINDINGS = [ - Binding("ctrl+q", "quit", "Quit", show=True), - Binding("ctrl+c", "clear_chat", "Clear", show=True), - Binding("ctrl+p", "command_palette", "Palette", show=True), - ] - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.agent_card: AgentCard | None = None - self.http_client: httpx.AsyncClient | None = None - self.context_id: str | None = None - self.agent_url: str | None = None - - def compose(self) -> ComposeResult: - with Container(id="root-container"): - with Vertical(id="left-pane"): - yield ContactPanel(id="contact-container", classes="panel") - yield AgentCardPanel(id="agent-card-container", classes="panel") - - with Vertical(id="right-pane"): - yield MessagesPanel(id="messages-container", classes="panel") - yield InputPanel(id="input-container", classes="panel") - - yield Footer(id="footer") - - async def on_mount(self) -> None: - logger.info("TUI application starting") - self.http_client = build_http_client() - self.theme = "gruvbox" - - root = self.query_one("#root-container", Container) - root.border_title = f"Handler v{__version__} [red]●[/red] Disconnected" - - messages = self.query_one("#messages-container", MessagesPanel) - messages.add_system_message("Welcome! Connect to an agent to start chatting.") - - def watch_theme(self, theme: str) -> None: - """Called when the app theme changes.""" - logger.debug("Theme changed to: %s", theme) - self.query_one("#agent-card-container", AgentCardPanel).refresh_theme() - - @on(Button.Pressed, "#footer-quit") - async def action_quit_app(self) -> None: - await self.action_quit() - - @on(Button.Pressed, "#footer-clear") - async def action_clear_chat_footer(self) -> None: - await self.action_clear_chat() - - @on(Button.Pressed, "#footer-palette") - def action_open_command_palette(self) -> None: - self.action_command_palette() - - async def _connect(self, agent_url: str) -> AgentCard: - if not self.http_client: - raise RuntimeError("HTTP client not initialized") - - logger.info("Connecting to agent at %s", agent_url) - return await fetch_agent_card(agent_url, self.http_client) - - def _update_ui_connected(self, card: AgentCard) -> None: - root = self.query_one("#root-container", Container) - root.border_title = ( - f"Handler v{__version__} [green]●[/green] Connected: {card.name}" - ) - - self.query_one("#agent-card-container", AgentCardPanel).update_card(card) - self.query_one("#contact-container", ContactPanel).set_connected(True) - self.query_one("#messages-container", MessagesPanel).update_message_count() - - def _update_ui_disconnected(self) -> None: - root = self.query_one("#root-container", Container) - root.border_title = f"Handler v{__version__} [red]●[/red] Disconnected" - - self.query_one("#agent-card-container", AgentCardPanel).update_card(None) - self.query_one("#contact-container", ContactPanel).set_connected(False) - - @on(Button.Pressed, "#connect-btn") - async def connect_to_agent(self) -> None: - contact = self.query_one("#contact-container", ContactPanel) - agent_url = contact.get_url() - - if not agent_url: - logger.warning("Connect attempted with empty URL") - messages = self.query_one("#messages-container", MessagesPanel) - messages.add_system_message("✗ Please enter an agent URL") - return - - messages = self.query_one("#messages-container", MessagesPanel) - messages.add_system_message(f"Connecting to {agent_url}...") - - try: - card = await self._connect(agent_url) - - self.agent_card = card - self.agent_url = agent_url - self.context_id = str(uuid.uuid4()) - - logger.info("Successfully connected to %s", card.name) - - self._update_ui_connected(card) - messages.add_system_message(f"✓ Connected to {card.name}") - self.query_one("#agent-card-container", AgentCardPanel).focus() - - except Exception as e: - logger.error("Connection failed: %s", e, exc_info=True) - messages.add_system_message(f"✗ Connection failed: {str(e)}") - - @on(Button.Pressed, "#disconnect-btn") - def disconnect_from_agent(self) -> None: - logger.info("Disconnecting from %s", self.agent_url) - self.agent_card = None - self.agent_url = None - - messages = self.query_one("#messages-container", MessagesPanel) - messages.add_system_message("Disconnected") - - self._update_ui_disconnected() - - @on(Input.Submitted, "#message-input") - async def send_on_enter(self) -> None: - if self.agent_url: - await self._send_message() - else: - messages = self.query_one("#messages-container", MessagesPanel) - messages.add_system_message("✗ Not connected to an agent") - - @on(Button.Pressed, "#send-btn") - async def send_button_pressed(self) -> None: - if self.agent_url: - await self._send_message() - else: - messages = self.query_one("#messages-container", MessagesPanel) - messages.add_system_message("✗ Not connected to an agent") - - async def _handle_agent_response(self, response_data: dict[str, Any]) -> str: - if not response_data: - return "Error: No result in response" - - texts = [] - if "parts" in response_data: - texts.extend(p.get("text", "") for p in response_data["parts"]) - - for artifact in response_data.get("artifacts", []): - texts.extend(p.get("text", "") for p in artifact.get("parts", [])) - - return "\n".join(t for t in texts if t) or "No text in response" - - async def _send_message(self) -> None: - if not self.agent_url or not self.http_client: - logger.warning("Attempted to send message without connection") - messages = self.query_one("#messages-container", MessagesPanel) - messages.add_system_message("✗ Not connected to an agent") - return - - input_panel = self.query_one("#input-container", InputPanel) - message_text = input_panel.get_message() - - if not message_text: - return - - messages = self.query_one("#messages-container", MessagesPanel) - messages.add_message("user", message_text) - - try: - logger.info("Sending message: %s", message_text[:50]) - - response_data = await send_message_to_agent( - self.agent_url, - message_text, - self.http_client, - context_id=self.context_id, - ) - - response_text = await self._handle_agent_response(response_data) - messages.add_message("agent", response_text) - - except Exception as e: - logger.error("Error sending message: %s", e, exc_info=True) - messages.add_system_message(f"✗ Error: {str(e)}") - - async def action_clear_chat(self) -> None: - messages = self.query_one("#messages-container", MessagesPanel) - await messages.clear() - - async def on_unmount(self) -> None: - logger.info("Shutting down TUI application") - if self.http_client: - await self.http_client.aclose() - - -def main() -> None: - """Entry point for the TUI application.""" - app = HandlerTUI() - app.run() - - -if __name__ == "__main__": - main() diff --git a/src/a2a_handler/tui/__init__.py b/src/a2a_handler/tui/__init__.py new file mode 100644 index 0000000..560dc7b --- /dev/null +++ b/src/a2a_handler/tui/__init__.py @@ -0,0 +1,5 @@ +"""Handler TUI application.""" + +from a2a_handler.tui.app import HandlerTUI, main + +__all__ = ["HandlerTUI", "main"] diff --git a/src/a2a_handler/tui/app.py b/src/a2a_handler/tui/app.py new file mode 100644 index 0000000..7e1e5c2 --- /dev/null +++ b/src/a2a_handler/tui/app.py @@ -0,0 +1,252 @@ +"""Handler TUI application.""" + +import logging +import uuid +from importlib.metadata import version +from typing import Any + +import httpx +from a2a.types import AgentCard +from textual import on +from textual.app import App, ComposeResult +from textual.binding import Binding +from textual.containers import Container, Vertical +from textual.logging import TextualHandler +from textual.widgets import Button, Input + +from a2a_handler.service import A2AService +from a2a_handler.tui.components import ( + AgentCardPanel, + ContactPanel, + Footer, + InputPanel, + MessagesPanel, +) + +__version__ = version("a2a-handler") + +logging.basicConfig( + level="NOTSET", + handlers=[TextualHandler()], +) +logger = logging.getLogger(__name__) + +DEFAULT_HTTP_TIMEOUT_SECONDS = 120 + + +def build_http_client( + timeout_seconds: int = DEFAULT_HTTP_TIMEOUT_SECONDS, +) -> httpx.AsyncClient: + """Build an HTTP client with the specified timeout.""" + return httpx.AsyncClient(timeout=timeout_seconds) + + +class HandlerTUI(App[Any]): + """Handler - A2A Agent Management Interface.""" + + CSS_PATH = "app.tcss" + + BINDINGS = [ + Binding("ctrl+q", "quit", "Quit", show=True), + Binding("ctrl+c", "clear_chat", "Clear", show=True), + Binding("ctrl+p", "command_palette", "Palette", show=True), + ] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.current_agent_card: AgentCard | None = None + self.http_client: httpx.AsyncClient | None = None + self.current_context_id: str | None = None + self.current_agent_url: str | None = None + self._agent_service: A2AService | None = None + + def compose(self) -> ComposeResult: + with Container(id="root-container"): + with Vertical(id="left-pane"): + yield ContactPanel(id="contact-container", classes="panel") + yield AgentCardPanel(id="agent-card-container", classes="panel") + + with Vertical(id="right-pane"): + yield MessagesPanel(id="messages-container", classes="panel") + yield InputPanel(id="input-container", classes="panel") + + yield Footer(id="footer") + + async def on_mount(self) -> None: + logger.info("TUI application starting") + self.http_client = build_http_client() + self.theme = "gruvbox" + + root_container = self.query_one("#root-container", Container) + root_container.border_title = ( + f"Handler v{__version__} [red]●[/red] Disconnected" + ) + + messages_panel = self.query_one("#messages-container", MessagesPanel) + messages_panel.add_system_message( + "Welcome! Connect to an agent to start chatting." + ) + + def watch_theme(self, new_theme: str) -> None: + """Called when the app theme changes.""" + logger.debug("Theme changed to: %s", new_theme) + agent_card_panel = self.query_one("#agent-card-container", AgentCardPanel) + agent_card_panel.refresh_theme() + + @on(Button.Pressed, "#footer-quit") + async def handle_quit_button(self) -> None: + await self.action_quit() + + @on(Button.Pressed, "#footer-clear") + async def handle_clear_button(self) -> None: + await self.action_clear_chat() + + @on(Button.Pressed, "#footer-palette") + def handle_palette_button(self) -> None: + self.action_command_palette() + + async def _connect_to_agent(self, agent_url: str) -> AgentCard: + if not self.http_client: + raise RuntimeError("HTTP client not initialized") + + logger.info("Connecting to agent at %s", agent_url) + self._agent_service = A2AService(self.http_client, agent_url) + return await self._agent_service.get_card() + + def _update_ui_for_connected_state(self, agent_card: AgentCard) -> None: + root_container = self.query_one("#root-container", Container) + root_container.border_title = ( + f"Handler v{__version__} [green]●[/green] Connected: {agent_card.name}" + ) + + agent_card_panel = self.query_one("#agent-card-container", AgentCardPanel) + agent_card_panel.update_card(agent_card) + + contact_panel = self.query_one("#contact-container", ContactPanel) + contact_panel.set_connected(True) + + messages_panel = self.query_one("#messages-container", MessagesPanel) + messages_panel.update_message_count() + + def _update_ui_for_disconnected_state(self) -> None: + root_container = self.query_one("#root-container", Container) + root_container.border_title = ( + f"Handler v{__version__} [red]●[/red] Disconnected" + ) + + agent_card_panel = self.query_one("#agent-card-container", AgentCardPanel) + agent_card_panel.update_card(None) + + contact_panel = self.query_one("#contact-container", ContactPanel) + contact_panel.set_connected(False) + + @on(Button.Pressed, "#connect-btn") + async def handle_connect_button(self) -> None: + contact_panel = self.query_one("#contact-container", ContactPanel) + agent_url = contact_panel.get_url() + + if not agent_url: + logger.warning("Connect attempted with empty URL") + messages_panel = self.query_one("#messages-container", MessagesPanel) + messages_panel.add_system_message("✗ Please enter an agent URL") + return + + messages_panel = self.query_one("#messages-container", MessagesPanel) + messages_panel.add_system_message(f"Connecting to {agent_url}...") + + try: + agent_card = await self._connect_to_agent(agent_url) + + self.current_agent_card = agent_card + self.current_agent_url = agent_url + self.current_context_id = str(uuid.uuid4()) + + logger.info("Successfully connected to %s", agent_card.name) + + self._update_ui_for_connected_state(agent_card) + messages_panel.add_system_message(f"✓ Connected to {agent_card.name}") + + agent_card_panel = self.query_one("#agent-card-container", AgentCardPanel) + agent_card_panel.focus() + + except Exception as error: + logger.error("Connection failed: %s", error, exc_info=True) + messages_panel.add_system_message(f"✗ Connection failed: {error!s}") + + @on(Button.Pressed, "#disconnect-btn") + def handle_disconnect_button(self) -> None: + logger.info("Disconnecting from %s", self.current_agent_url) + self.current_agent_card = None + self.current_agent_url = None + self._agent_service = None + + messages_panel = self.query_one("#messages-container", MessagesPanel) + messages_panel.add_system_message("Disconnected") + + self._update_ui_for_disconnected_state() + + @on(Input.Submitted, "#message-input") + async def handle_message_submit(self) -> None: + if self.current_agent_url: + await self._send_message() + else: + messages_panel = self.query_one("#messages-container", MessagesPanel) + messages_panel.add_system_message("✗ Not connected to an agent") + + @on(Button.Pressed, "#send-btn") + async def handle_send_button(self) -> None: + if self.current_agent_url: + await self._send_message() + else: + messages_panel = self.query_one("#messages-container", MessagesPanel) + messages_panel.add_system_message("✗ Not connected to an agent") + + async def _send_message(self) -> None: + if not self.current_agent_url or not self._agent_service: + logger.warning("Attempted to send message without connection") + messages_panel = self.query_one("#messages-container", MessagesPanel) + messages_panel.add_system_message("✗ Not connected to an agent") + return + + input_panel = self.query_one("#input-container", InputPanel) + message_text = input_panel.get_message() + + if not message_text: + return + + messages_panel = self.query_one("#messages-container", MessagesPanel) + messages_panel.add_message("user", message_text) + + try: + logger.info("Sending message: %s", message_text[:50]) + + send_result = await self._agent_service.send( + message_text, + context_id=self.current_context_id, + ) + + response_text = send_result.text or "No text in response" + messages_panel.add_message("agent", response_text) + + except Exception as error: + logger.error("Error sending message: %s", error, exc_info=True) + messages_panel.add_system_message(f"✗ Error: {error!s}") + + async def action_clear_chat(self) -> None: + messages_panel = self.query_one("#messages-container", MessagesPanel) + await messages_panel.clear() + + async def on_unmount(self) -> None: + logger.info("Shutting down TUI application") + if self.http_client: + await self.http_client.aclose() + + +def main() -> None: + """Entry point for the TUI application.""" + application = HandlerTUI() + application.run() + + +if __name__ == "__main__": + main() diff --git a/src/a2a_handler/tui.tcss b/src/a2a_handler/tui/app.tcss similarity index 100% rename from src/a2a_handler/tui.tcss rename to src/a2a_handler/tui/app.tcss diff --git a/src/a2a_handler/components/__init__.py b/src/a2a_handler/tui/components/__init__.py similarity index 87% rename from src/a2a_handler/components/__init__.py rename to src/a2a_handler/tui/components/__init__.py index aff261c..31884df 100644 --- a/src/a2a_handler/components/__init__.py +++ b/src/a2a_handler/tui/components/__init__.py @@ -1,4 +1,4 @@ -from .agent_card import AgentCardPanel +from .card import AgentCardPanel from .contact import ContactPanel from .footer import Footer from .input import InputPanel diff --git a/src/a2a_handler/tui/components/card.py b/src/a2a_handler/tui/components/card.py new file mode 100644 index 0000000..a60059f --- /dev/null +++ b/src/a2a_handler/tui/components/card.py @@ -0,0 +1,266 @@ +"""Agent card panel component for displaying agent metadata.""" + +import json +import re +from typing import Any + +from a2a.types import AgentCard +from rich.syntax import Syntax +from textual.app import ComposeResult +from textual.binding import Binding +from textual.containers import Container, VerticalScroll +from textual.widgets import Static, TabbedContent, TabPane, Tabs + +from a2a_handler.common import get_logger + +logger = get_logger(__name__) + +TEXTUAL_TO_SYNTAX_THEME_MAP: dict[str, str] = { + "gruvbox": "gruvbox-dark", + "nord": "nord", + "tokyo-night": "monokai", + "textual-dark": "monokai", + "textual-light": "default", + "solarized-light": "solarized-light", + "dracula": "dracula", + "catppuccin-mocha": "monokai", + "monokai": "monokai", +} + +SHORT_VIEW_FIELDS = [ + "name", + "description", + "version", + "url", + "defaultInputModes", + "defaultOutputModes", +] + + +class AgentCardPanel(Container): + """Panel displaying agent card information with tabs.""" + + BINDINGS = [ + Binding("h", "previous_tab", "Previous Tab", show=False), + Binding("l", "next_tab", "Next Tab", show=False), + Binding("left", "previous_tab", "Previous Tab", show=False), + Binding("right", "next_tab", "Next Tab", show=False), + Binding("j", "scroll_down", "Scroll Down", show=False), + Binding("k", "scroll_up", "Scroll Up", show=False), + Binding("down", "scroll_down", "Scroll Down", show=False), + Binding("up", "scroll_up", "Scroll Up", show=False), + ] + + can_focus = True + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._current_agent_card: AgentCard | None = None + + def compose(self) -> ComposeResult: + with TabbedContent(id="agent-card-tabs"): + with TabPane("Short", id="short-tab"): + yield VerticalScroll( + Static("Not connected", id="agent-short"), + id="short-scroll", + ) + with TabPane("Long", id="long-tab"): + yield VerticalScroll( + Static("", id="agent-long"), + id="long-scroll", + ) + with TabPane("Raw", id="raw-tab"): + yield VerticalScroll( + Static("", id="agent-raw"), + id="raw-scroll", + ) + + def on_mount(self) -> None: + self.border_title = "AGENT CARD" + self.border_subtitle = "READY" + for widget in self.query("TabbedContent, Tabs, Tab, TabPane, VerticalScroll"): + widget.can_focus = False + logger.debug("Agent card panel mounted") + + def _get_syntax_theme_for_current_app_theme(self) -> str: + """Get the Rich Syntax theme name for the current app theme.""" + current_theme = self.app.theme or "" + return TEXTUAL_TO_SYNTAX_THEME_MAP.get(current_theme, "monokai") + + def _convert_key_to_sentence_case(self, key: str) -> str: + """Convert a camelCase or snake_case key to sentence case.""" + spaced_key = re.sub(r"([a-z])([A-Z])", r"\1 \2", key) + return spaced_key.replace("_", " ").capitalize() + + def _is_value_empty(self, value: Any) -> bool: + """Check if a value is truly empty, including nested structures.""" + if value is None: + return True + if isinstance(value, (str, list, dict)) and not value: + return True + if isinstance(value, dict): + return all( + self._is_value_empty(nested_value) for nested_value in value.values() + ) + if isinstance(value, list): + return all(self._is_value_empty(item) for item in value) + return False + + def _format_nested_value(self, value: Any, indentation_level: int = 0) -> str: + """Format a nested value for display with proper indentation.""" + indentation_prefix = " " * indentation_level + + if isinstance(value, dict): + formatted_lines = [] + for key, nested_value in value.items(): + if self._is_value_empty(nested_value): + continue + formatted_key = self._convert_key_to_sentence_case(key) + if isinstance(nested_value, (list, dict)): + formatted_lines.append( + f"{indentation_prefix}[bold]{formatted_key}[/]" + ) + formatted_lines.append( + self._format_nested_value(nested_value, indentation_level + 1) + ) + else: + formatted_lines.append( + f"{indentation_prefix}[bold]{formatted_key}:[/] {nested_value}" + ) + return "\n".join(formatted_lines) + + if isinstance(value, list): + formatted_lines = [] + for item in value: + if self._is_value_empty(item): + continue + if isinstance(item, dict): + formatted_lines.append( + self._format_nested_value(item, indentation_level) + ) + else: + formatted_lines.append(f"{indentation_prefix}• {item}") + return "\n".join(formatted_lines) + + return f"{indentation_prefix}{value}" + + def _build_short_view_content(self, agent_card: AgentCard) -> str: + """Build the short view with essential fields only.""" + card_data = agent_card.model_dump() + formatted_lines = [] + + for field_name in SHORT_VIEW_FIELDS: + field_value = card_data.get(field_name) + if self._is_value_empty(field_value): + continue + formatted_key = self._convert_key_to_sentence_case(field_name) + if isinstance(field_value, (list, dict)): + formatted_lines.append(f"[bold]{formatted_key}[/]") + formatted_lines.append( + self._format_nested_value(field_value, indentation_level=1) + ) + else: + formatted_lines.append(f"[bold]{formatted_key}:[/] {field_value}") + + return "\n".join(formatted_lines) + + def _build_long_view_content(self, agent_card: AgentCard) -> str: + """Build the long view with all non-empty fields.""" + card_data = agent_card.model_dump() + formatted_lines = [] + + for field_name, field_value in card_data.items(): + if self._is_value_empty(field_value): + continue + formatted_key = self._convert_key_to_sentence_case(field_name) + if isinstance(field_value, (list, dict)): + formatted_lines.append(f"[bold]{formatted_key}[/]") + formatted_lines.append( + self._format_nested_value(field_value, indentation_level=1) + ) + else: + formatted_lines.append(f"[bold]{formatted_key}:[/] {field_value}") + + return "\n".join(formatted_lines) + + def update_card(self, agent_card: AgentCard | None) -> None: + """Update the displayed agent card.""" + self._current_agent_card = agent_card + + short_view_widget = self.query_one("#agent-short", Static) + long_view_widget = self.query_one("#agent-long", Static) + raw_view_widget = self.query_one("#agent-raw", Static) + + if agent_card is None: + logger.debug("Clearing agent card display") + short_view_widget.update("Not connected") + long_view_widget.update("") + raw_view_widget.update("") + self.border_subtitle = "READY" + else: + logger.info("Displaying agent card for: %s", agent_card.name) + short_view_widget.update(self._build_short_view_content(agent_card)) + long_view_widget.update(self._build_long_view_content(agent_card)) + + json_content = json.dumps(agent_card.model_dump(), indent=2, default=str) + syntax_theme = self._get_syntax_theme_for_current_app_theme() + raw_view_widget.update(Syntax(json_content, "json", theme=syntax_theme)) + self.border_subtitle = "ACTIVE" + + def refresh_theme(self) -> None: + """Refresh the raw view syntax highlighting for theme changes.""" + if self._current_agent_card is None: + return + + logger.debug("Refreshing syntax theme for agent card raw view") + json_content = json.dumps( + self._current_agent_card.model_dump(), indent=2, default=str + ) + syntax_theme = self._get_syntax_theme_for_current_app_theme() + self.query_one("#agent-raw", Static).update( + Syntax(json_content, "json", theme=syntax_theme) + ) + + def _get_currently_active_scroll_container(self) -> VerticalScroll | None: + """Get the currently visible scroll container.""" + tabbed_content = self.query_one("#agent-card-tabs", TabbedContent) + active_tab_id = tabbed_content.active + + scroll_container_map = { + "short-tab": "#short-scroll", + "long-tab": "#long-scroll", + "raw-tab": "#raw-scroll", + } + + scroll_container_id = scroll_container_map.get(active_tab_id) + if scroll_container_id: + return self.query_one(scroll_container_id, VerticalScroll) + return None + + def action_previous_tab(self) -> None: + """Switch to the previous tab.""" + try: + tabs_widget = self.query_one("#agent-card-tabs Tabs", Tabs) + tabs_widget.action_previous_tab() + except Exception: + pass + + def action_next_tab(self) -> None: + """Switch to the next tab.""" + try: + tabs_widget = self.query_one("#agent-card-tabs Tabs", Tabs) + tabs_widget.action_next_tab() + except Exception: + pass + + def action_scroll_down(self) -> None: + """Scroll down in the active tab's scroll container.""" + scroll_container = self._get_currently_active_scroll_container() + if scroll_container: + scroll_container.scroll_down() + + def action_scroll_up(self) -> None: + """Scroll up in the active tab's scroll container.""" + scroll_container = self._get_currently_active_scroll_container() + if scroll_container: + scroll_container.scroll_up() diff --git a/src/a2a_handler/components/contact.py b/src/a2a_handler/tui/components/contact.py similarity index 54% rename from src/a2a_handler/components/contact.py rename to src/a2a_handler/tui/components/contact.py index 4fad077..da8612c 100644 --- a/src/a2a_handler/components/contact.py +++ b/src/a2a_handler/tui/components/contact.py @@ -1,10 +1,12 @@ -import logging +"""Contact panel component for agent connection management.""" from textual.app import ComposeResult from textual.containers import Container, Horizontal from textual.widgets import Button, Input, Label -logger = logging.getLogger(__name__) +from a2a_handler.common import get_logger + +logger = get_logger(__name__) class ContactPanel(Container): @@ -23,12 +25,19 @@ def compose(self) -> ComposeResult: def on_mount(self) -> None: self.border_title = "CONTACT" + logger.debug("Contact panel mounted") - def set_connected(self, connected: bool) -> None: + def set_connected(self, is_connected: bool) -> None: """Update button states based on connection status.""" - self.query_one("#connect-btn", Button).disabled = connected - self.query_one("#disconnect-btn", Button).disabled = not connected + connect_button = self.query_one("#connect-btn", Button) + disconnect_button = self.query_one("#disconnect-btn", Button) + + connect_button.disabled = is_connected + disconnect_button.disabled = not is_connected + + logger.debug("Connection state updated: connected=%s", is_connected) def get_url(self) -> str: - """Get the current agent URL.""" - return self.query_one("#agent-url", Input).value.strip() + """Get the current agent URL from the input field.""" + url_input = self.query_one("#agent-url", Input) + return url_input.value.strip() diff --git a/src/a2a_handler/components/footer.py b/src/a2a_handler/tui/components/footer.py similarity index 90% rename from src/a2a_handler/components/footer.py rename to src/a2a_handler/tui/components/footer.py index b428611..b7cd5d4 100644 --- a/src/a2a_handler/components/footer.py +++ b/src/a2a_handler/tui/components/footer.py @@ -1,3 +1,5 @@ +"""Footer component with keyboard shortcut buttons.""" + from textual.app import ComposeResult from textual.containers import Container, Horizontal from textual.widgets import Button diff --git a/src/a2a_handler/components/input.py b/src/a2a_handler/tui/components/input.py similarity index 74% rename from src/a2a_handler/components/input.py rename to src/a2a_handler/tui/components/input.py index 8a3508a..0884926 100644 --- a/src/a2a_handler/components/input.py +++ b/src/a2a_handler/tui/components/input.py @@ -1,10 +1,12 @@ -import logging +"""Input panel component for message composition.""" from textual.app import ComposeResult from textual.containers import Container, Horizontal from textual.widgets import Button, Input -logger = logging.getLogger(__name__) +from a2a_handler.common import get_logger + +logger = get_logger(__name__) class InputPanel(Container): @@ -19,14 +21,15 @@ def on_mount(self) -> None: self.border_title = "INPUT" self.border_subtitle = "PRESS ENTER TO SEND" self.query_one("#send-btn", Button).can_focus = False + logger.debug("Input panel mounted") def get_message(self) -> str: """Get and clear the current message input.""" message_input = self.query_one("#message-input", Input) - message = message_input.value.strip() + message_text = message_input.value.strip() message_input.value = "" - return message + return message_text def focus_input(self) -> None: - """Focus the message input.""" + """Focus the message input field.""" self.query_one("#message-input", Input).focus() diff --git a/src/a2a_handler/components/messages.py b/src/a2a_handler/tui/components/messages.py similarity index 59% rename from src/a2a_handler/components/messages.py rename to src/a2a_handler/tui/components/messages.py index 113ba21..1a3527d 100644 --- a/src/a2a_handler/components/messages.py +++ b/src/a2a_handler/tui/components/messages.py @@ -1,4 +1,5 @@ -import logging +"""Messages panel component for chat display.""" + from datetime import datetime from typing import Any @@ -7,7 +8,9 @@ from textual.containers import Container, Vertical, VerticalScroll from textual.widgets import Static -logger = logging.getLogger(__name__) +from a2a_handler.common import get_logger + +logger = get_logger(__name__) class Message(Vertical): @@ -26,16 +29,16 @@ def __init__( self.timestamp = timestamp or datetime.now() def compose(self) -> ComposeResult: - time_str = self.timestamp.strftime("%H:%M:%S") + formatted_time = self.timestamp.strftime("%H:%M:%S") if self.role == "system": - yield Static(f"[dim]{time_str}[/dim] [italic]{self.text}[/italic]") + yield Static(f"[dim]{formatted_time}[/dim] [italic]{self.text}[/italic]") else: role_color = "#88c0d0" if self.role == "agent" else "#bf616a" - yield Static(f"[dim]{time_str}[/dim] [{role_color}]{self.text}[/]") + yield Static(f"[dim]{formatted_time}[/dim] [{role_color}]{self.text}[/]") -class ChatScroll(VerticalScroll): +class ChatScrollContainer(VerticalScroll): """Scrollable chat area.""" can_focus = False @@ -54,37 +57,39 @@ class MessagesPanel(Container): can_focus = True def compose(self) -> ComposeResult: - yield ChatScroll(id="chat") + yield ChatScrollContainer(id="chat") def on_mount(self) -> None: self.border_title = "MESSAGES" + logger.debug("Messages panel mounted") - def _get_chat(self) -> ChatScroll: - return self.query_one("#chat", ChatScroll) + def _get_chat_container(self) -> ChatScrollContainer: + return self.query_one("#chat", ChatScrollContainer) def add_message(self, role: str, content: str) -> None: logger.debug("Adding %s message: %s", role, content[:50]) - chat = self._get_chat() - message = Message(role, content) - chat.mount(message) - chat.scroll_end(animate=False) + chat_container = self._get_chat_container() + message_widget = Message(role, content) + chat_container.mount(message_widget) + chat_container.scroll_end(animate=False) def add_system_message(self, content: str) -> None: logger.info("System message: %s", content) self.add_message("system", content) def update_message_count(self) -> None: - chat = self._get_chat() - self.border_subtitle = f"{len(chat.children)} MESSAGES" + chat_container = self._get_chat_container() + message_count = len(chat_container.children) + self.border_subtitle = f"{message_count} MESSAGES" async def clear(self) -> None: - logger.info("Clearing chat") - chat = self._get_chat() - await chat.remove_children() + logger.info("Clearing chat messages") + chat_container = self._get_chat_container() + await chat_container.remove_children() self.add_system_message("Chat cleared") def action_scroll_down(self) -> None: - self._get_chat().scroll_down() + self._get_chat_container().scroll_down() def action_scroll_up(self) -> None: - self._get_chat().scroll_up() + self._get_chat_container().scroll_up() diff --git a/src/a2a_handler/validation.py b/src/a2a_handler/validation.py index 292c48f..6d876c0 100644 --- a/src/a2a_handler/validation.py +++ b/src/a2a_handler/validation.py @@ -12,7 +12,7 @@ from a2a_handler.common import get_logger -log = get_logger(__name__) +logger = get_logger(__name__) class ValidationSource(Enum): @@ -26,12 +26,12 @@ class ValidationSource(Enum): class ValidationIssue: """Represents a single validation issue.""" - field: str + field_name: str message: str issue_type: str = "error" def __str__(self) -> str: - return f"[{self.issue_type}] {self.field}: {self.message}" + return f"[{self.issue_type}] {self.field_name}: {self.message}" @dataclass @@ -65,24 +65,26 @@ def protocol_version(self) -> str: return "Unknown" -def _parse_pydantic_error(error: ValidationError) -> list[ValidationIssue]: +def parse_pydantic_validation_errors( + validation_error: ValidationError, +) -> list[ValidationIssue]: """Parse Pydantic validation errors into ValidationIssues.""" - issues = [] - for err in error.errors(): - field_path = ".".join(str(loc) for loc in err["loc"]) - message = err["msg"] - issue_type = err["type"] - issues.append( + validation_issues = [] + for error_detail in validation_error.errors(): + field_path = ".".join(str(location) for location in error_detail["loc"]) + error_message = error_detail["msg"] + error_type = error_detail["type"] + validation_issues.append( ValidationIssue( - field=field_path or "root", - message=message, - issue_type=issue_type, + field_name=field_path or "root", + message=error_message, + issue_type=error_type, ) ) - return issues + return validation_issues -def _check_best_practices(card: AgentCard) -> list[ValidationIssue]: +def check_agent_card_best_practices(agent_card: AgentCard) -> list[ValidationIssue]: """Check for best practices and generate warnings. Note: In A2A v0.3.0, the following are REQUIRED fields and validated by Pydantic: @@ -92,192 +94,197 @@ def _check_best_practices(card: AgentCard) -> list[ValidationIssue]: This function only warns about optional fields that improve agent discoverability. """ - warnings = [] + best_practice_warnings = [] - if not card.provider: - warnings.append( + if not agent_card.provider: + best_practice_warnings.append( ValidationIssue( - field="provider", + field_name="provider", message="Agent card should specify a provider for better discoverability", issue_type="warning", ) ) - if not card.documentation_url: - warnings.append( + if not agent_card.documentation_url: + best_practice_warnings.append( ValidationIssue( - field="documentationUrl", + field_name="documentationUrl", message="Agent card should include documentation URL", issue_type="warning", ) ) - if not card.icon_url: - warnings.append( + if not agent_card.icon_url: + best_practice_warnings.append( ValidationIssue( - field="iconUrl", + field_name="iconUrl", message="Agent card should include an icon URL for UI display", issue_type="warning", ) ) - if card.skills: - for i, skill in enumerate(card.skills): + if agent_card.skills: + for skill_index, skill in enumerate(agent_card.skills): if not skill.description: - warnings.append( + best_practice_warnings.append( ValidationIssue( - field=f"skills[{i}].description", + field_name=f"skills[{skill_index}].description", message=f"Skill '{skill.name}' should have a description", issue_type="warning", ) ) if not skill.examples or len(skill.examples) == 0: - warnings.append( + best_practice_warnings.append( ValidationIssue( - field=f"skills[{i}].examples", + field_name=f"skills[{skill_index}].examples", message=f"Skill '{skill.name}' should include example prompts", issue_type="warning", ) ) - if not card.additional_interfaces or len(card.additional_interfaces) == 0: - warnings.append( + if ( + not agent_card.additional_interfaces + or len(agent_card.additional_interfaces) == 0 + ): + best_practice_warnings.append( ValidationIssue( - field="additionalInterfaces", + field_name="additionalInterfaces", message="Consider declaring additional transport interfaces for flexibility", issue_type="warning", ) ) - return warnings + return best_practice_warnings def validate_agent_card_data( - data: dict[str, Any], source: str, source_type: ValidationSource + card_data: dict[str, Any], + source: str, + source_type: ValidationSource, ) -> ValidationResult: """Validate agent card data against the A2A protocol schema. Args: - data: Raw agent card data as a dictionary + card_data: Raw agent card data as a dictionary source: The source (URL or file path) of the data source_type: Whether the source is a URL or file Returns: ValidationResult with validation status and any issues """ - log.debug("Validating agent card data from %s", source) + logger.debug("Validating agent card data from %s", source) try: - card = AgentCard.model_validate(data) - log.info("Agent card validation successful for %s", card.name) + agent_card = AgentCard.model_validate(card_data) + logger.info("Agent card validation successful for %s", agent_card.name) - warnings = _check_best_practices(card) + best_practice_warnings = check_agent_card_best_practices(agent_card) return ValidationResult( valid=True, source=source, source_type=source_type, - agent_card=card, - warnings=warnings, - raw_data=data, + agent_card=agent_card, + warnings=best_practice_warnings, + raw_data=card_data, ) - except ValidationError as e: - log.warning("Agent card validation failed: %s", e) - issues = _parse_pydantic_error(e) + except ValidationError as validation_error: + logger.warning("Agent card validation failed: %s", validation_error) + validation_issues = parse_pydantic_validation_errors(validation_error) return ValidationResult( valid=False, source=source, source_type=source_type, - issues=issues, - raw_data=data, + issues=validation_issues, + raw_data=card_data, ) async def validate_agent_card_from_url( - url: str, - client: httpx.AsyncClient | None = None, - card_path: str | None = None, + agent_url: str, + http_client: httpx.AsyncClient | None = None, + agent_card_path: str | None = None, ) -> ValidationResult: """Fetch and validate an agent card from a URL. Args: - url: The base URL of the agent - client: Optional HTTP client to use - card_path: Optional custom path to the agent card (default: /.well-known/agent.json) + agent_url: The base URL of the agent + http_client: Optional HTTP client to use + agent_card_path: Optional custom path to the agent card (default: /.well-known/agent.json) Returns: ValidationResult with validation status and any issues """ - log.info("Validating agent card from URL: %s", url) + logger.info("Validating agent card from URL: %s", agent_url) - should_close = client is None - if client is None: - client = httpx.AsyncClient(timeout=30) + should_close_client = http_client is None + if http_client is None: + http_client = httpx.AsyncClient(timeout=30) try: - base_url = url.rstrip("/") - if card_path: - full_url = f"{base_url}/{card_path.lstrip('/')}" + base_url = agent_url.rstrip("/") + if agent_card_path: + full_url = f"{base_url}/{agent_card_path.lstrip('/')}" else: full_url = f"{base_url}/.well-known/agent-card.json" - log.debug("Fetching agent card from %s", full_url) - response = await client.get(full_url) + logger.debug("Fetching agent card from %s", full_url) + response = await http_client.get(full_url) response.raise_for_status() - data = response.json() - return validate_agent_card_data(data, url, ValidationSource.URL) + card_data = response.json() + return validate_agent_card_data(card_data, agent_url, ValidationSource.URL) - except httpx.HTTPStatusError as e: - log.error("HTTP error fetching agent card: %s", e) + except httpx.HTTPStatusError as http_error: + logger.error("HTTP error fetching agent card: %s", http_error) return ValidationResult( valid=False, - source=url, + source=agent_url, source_type=ValidationSource.URL, issues=[ ValidationIssue( - field="http", - message=f"HTTP {e.response.status_code}: {e.response.text[:200]}", + field_name="http", + message=f"HTTP {http_error.response.status_code}: {http_error.response.text[:200]}", issue_type="http_error", ) ], ) - except httpx.RequestError as e: - log.error("Request error fetching agent card: %s", e) + except httpx.RequestError as request_error: + logger.error("Request error fetching agent card: %s", request_error) return ValidationResult( valid=False, - source=url, + source=agent_url, source_type=ValidationSource.URL, issues=[ ValidationIssue( - field="connection", - message=str(e), + field_name="connection", + message=str(request_error), issue_type="connection_error", ) ], ) - except json.JSONDecodeError as e: - log.error("JSON decode error: %s", e) + except json.JSONDecodeError as json_error: + logger.error("JSON decode error: %s", json_error) return ValidationResult( valid=False, - source=url, + source=agent_url, source_type=ValidationSource.URL, issues=[ ValidationIssue( - field="json", - message=f"Invalid JSON: {e}", + field_name="json", + message=f"Invalid JSON: {json_error}", issue_type="json_error", ) ], ) finally: - if should_close: - await client.aclose() + if should_close_client: + await http_client.aclose() def validate_agent_card_from_file(file_path: str | Path) -> ValidationResult: @@ -290,17 +297,17 @@ def validate_agent_card_from_file(file_path: str | Path) -> ValidationResult: ValidationResult with validation status and any issues """ path = Path(file_path) - log.info("Validating agent card from file: %s", path) + logger.info("Validating agent card from file: %s", path) if not path.exists(): - log.error("File not found: %s", path) + logger.error("File not found: %s", path) return ValidationResult( valid=False, source=str(path), source_type=ValidationSource.FILE, issues=[ ValidationIssue( - field="file", + field_name="file", message=f"File not found: {path}", issue_type="file_error", ) @@ -308,14 +315,14 @@ def validate_agent_card_from_file(file_path: str | Path) -> ValidationResult: ) if not path.is_file(): - log.error("Path is not a file: %s", path) + logger.error("Path is not a file: %s", path) return ValidationResult( valid=False, source=str(path), source_type=ValidationSource.FILE, issues=[ ValidationIssue( - field="file", + field_name="file", message=f"Path is not a file: {path}", issue_type="file_error", ) @@ -323,51 +330,51 @@ def validate_agent_card_from_file(file_path: str | Path) -> ValidationResult: ) try: - with open(path, encoding="utf-8") as f: - data = json.load(f) + with open(path, encoding="utf-8") as card_file: + card_data = json.load(card_file) - return validate_agent_card_data(data, str(path), ValidationSource.FILE) + return validate_agent_card_data(card_data, str(path), ValidationSource.FILE) - except json.JSONDecodeError as e: - log.error("JSON decode error: %s", e) + except json.JSONDecodeError as json_error: + logger.error("JSON decode error: %s", json_error) return ValidationResult( valid=False, source=str(path), source_type=ValidationSource.FILE, issues=[ ValidationIssue( - field="json", - message=f"Invalid JSON at line {e.lineno}, column {e.colno}: {e.msg}", + field_name="json", + message=f"Invalid JSON at line {json_error.lineno}, column {json_error.colno}: {json_error.msg}", issue_type="json_error", ) ], ) except PermissionError: - log.error("Permission denied reading file: %s", path) + logger.error("Permission denied reading file: %s", path) return ValidationResult( valid=False, source=str(path), source_type=ValidationSource.FILE, issues=[ ValidationIssue( - field="file", + field_name="file", message=f"Permission denied: {path}", issue_type="file_error", ) ], ) - except OSError as e: - log.error("Error reading file: %s", e) + except OSError as os_error: + logger.error("Error reading file: %s", os_error) return ValidationResult( valid=False, source=str(path), source_type=ValidationSource.FILE, issues=[ ValidationIssue( - field="file", - message=str(e), + field_name="file", + message=str(os_error), issue_type="file_error", ) ], diff --git a/src/a2a_handler/webhook.py b/src/a2a_handler/webhook.py new file mode 100644 index 0000000..268aabe --- /dev/null +++ b/src/a2a_handler/webhook.py @@ -0,0 +1,177 @@ +"""Local webhook server for receiving A2A push notifications. + +This module provides a simple HTTP server that can receive push notifications +from A2A agents for testing purposes. +""" + +import json +from collections import deque +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +import uvicorn +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.routing import Route + +from a2a_handler.common import console, get_logger + +logger = get_logger(__name__) + +DEFAULT_MAX_STORED_NOTIFICATIONS = 100 + + +@dataclass +class PushNotification: + """A received push notification.""" + + timestamp: datetime + task_id: str | None + payload: dict[str, Any] + headers: dict[str, str] + + +@dataclass +class PushNotificationStore: + """In-memory store for received notifications.""" + + notifications: deque[PushNotification] = field( + default_factory=lambda: deque(maxlen=DEFAULT_MAX_STORED_NOTIFICATIONS) + ) + + def add_notification(self, notification: PushNotification) -> None: + """Add a notification to the store.""" + self.notifications.append(notification) + logger.debug( + "Stored notification for task: %s (total: %d)", + notification.task_id, + len(self.notifications), + ) + + def get_all_notifications(self) -> list[PushNotification]: + """Get all stored notifications.""" + return list(self.notifications) + + def clear_all_notifications(self) -> None: + """Clear all stored notifications.""" + notification_count = len(self.notifications) + self.notifications.clear() + logger.info("Cleared %d stored notifications", notification_count) + + +notification_store = PushNotificationStore() + + +async def handle_push_notification(request: Request) -> JSONResponse: + """Handle incoming push notifications from A2A agents.""" + try: + request_payload = await request.json() + except json.JSONDecodeError: + logger.warning("Received invalid JSON in push notification") + return JSONResponse({"error": "Invalid JSON"}, status_code=400) + + request_headers = dict(request.headers) + task_id = request_payload.get("id") or request_payload.get("task_id") + + notification = PushNotification( + timestamp=datetime.now(), + task_id=task_id, + payload=request_payload, + headers=request_headers, + ) + notification_store.add_notification(notification) + + logger.info("Received push notification for task: %s", task_id) + + console.print("\n[bold cyan]Push Notification Received[/bold cyan]") + console.print(f"[dim]Timestamp:[/dim] {notification.timestamp.isoformat()}") + if task_id: + console.print(f"[dim]Task ID:[/dim] {task_id}") + + task_status = request_payload.get("status", {}) + if task_status: + task_state = task_status.get("state", "unknown") + console.print(f"[dim]State:[/dim] {task_state}") + + authentication_token = request_headers.get("x-a2a-notification-token") + if authentication_token: + console.print(f"[dim]Token:[/dim] {authentication_token[:20]}...") + + console.print() + console.print_json(json.dumps(request_payload, indent=2, default=str)) + console.print() + + return JSONResponse({"status": "ok", "received": True}) + + +async def handle_webhook_validation(request: Request) -> JSONResponse: + """Handle GET requests for webhook validation.""" + logger.info("Webhook validation request received") + return JSONResponse({"status": "ok", "message": "Webhook is active"}) + + +async def handle_list_notifications(request: Request) -> JSONResponse: + """List all received notifications.""" + all_notifications = notification_store.get_all_notifications() + logger.debug("Returning %d stored notifications", len(all_notifications)) + return JSONResponse( + { + "count": len(all_notifications), + "notifications": [ + { + "timestamp": notification.timestamp.isoformat(), + "task_id": notification.task_id, + "payload": notification.payload, + } + for notification in all_notifications + ], + } + ) + + +async def handle_clear_notifications(request: Request) -> JSONResponse: + """Clear all stored notifications.""" + notification_store.clear_all_notifications() + return JSONResponse({"status": "ok", "message": "Notifications cleared"}) + + +def create_webhook_application() -> Starlette: + """Create the webhook Starlette application.""" + application_routes = [ + Route("/webhook", handle_push_notification, methods=["POST"]), + Route("/webhook", handle_webhook_validation, methods=["GET"]), + Route("/notifications", handle_list_notifications, methods=["GET"]), + Route("/notifications/clear", handle_clear_notifications, methods=["POST"]), + ] + return Starlette(routes=application_routes) + + +def run_webhook_server( + host: str = "127.0.0.1", + port: int = 9000, +) -> None: + """Start the webhook server. + + Args: + host: Host address to bind to + port: Port number to bind to + """ + console.print(f"\n[bold]Starting webhook server on [url]{host}:{port}[/url][/bold]") + console.print() + console.print("[dim]Endpoints:[/dim]") + console.print(f" POST http://{host}:{port}/webhook - Receive notifications") + console.print(f" GET http://{host}:{port}/webhook - Validation check") + console.print(f" GET http://{host}:{port}/notifications - List received") + console.print(f" POST http://{host}:{port}/notifications/clear - Clear stored") + console.print() + console.print( + f"[bold green]Use this URL for push notifications:[/bold green] " + f"http://{host}:{port}/webhook" + ) + console.print() + + logger.info("Starting webhook server on %s:%d", host, port) + webhook_application = create_webhook_application() + uvicorn.run(webhook_application, host=host, port=port, log_level="warning") diff --git a/tests/test_service.py b/tests/test_service.py new file mode 100644 index 0000000..78f02bb --- /dev/null +++ b/tests/test_service.py @@ -0,0 +1,165 @@ +"""Tests for A2A service layer.""" + +from a2a.types import Part, Task, TaskState, TaskStatus, TextPart + +from a2a_handler.service import ( + SendResult, + StreamEvent, + TaskResult, + TERMINAL_TASK_STATES, + extract_text_from_message_parts, +) + + +class TestSendResult: + """Tests for SendResult dataclass.""" + + def test_is_complete_when_completed(self): + """Test is_complete returns True for completed state.""" + result = SendResult(state=TaskState.completed) + assert result.is_complete is True + + def test_is_complete_when_canceled(self): + """Test is_complete returns True for canceled state.""" + result = SendResult(state=TaskState.canceled) + assert result.is_complete is True + + def test_is_complete_when_failed(self): + """Test is_complete returns True for failed state.""" + result = SendResult(state=TaskState.failed) + assert result.is_complete is True + + def test_is_complete_when_rejected(self): + """Test is_complete returns True for rejected state.""" + result = SendResult(state=TaskState.rejected) + assert result.is_complete is True + + def test_is_complete_when_working(self): + """Test is_complete returns False for working state.""" + result = SendResult(state=TaskState.working) + assert result.is_complete is False + + def test_is_complete_when_no_state(self): + """Test is_complete returns False when state is None.""" + result = SendResult() + assert result.is_complete is False + + def test_needs_input_when_input_required(self): + """Test needs_input returns True for input_required state.""" + result = SendResult(state=TaskState.input_required) + assert result.needs_input is True + + def test_needs_input_when_working(self): + """Test needs_input returns False for working state.""" + result = SendResult(state=TaskState.working) + assert result.needs_input is False + + def test_needs_input_when_no_state(self): + """Test needs_input returns False when state is None.""" + result = SendResult() + assert result.needs_input is False + + +class TestStreamEvent: + """Tests for StreamEvent dataclass.""" + + def test_create_message_event(self): + """Test creating a message event.""" + event = StreamEvent( + event_type="message", + context_id="ctx-123", + text="Hello, world!", + ) + + assert event.event_type == "message" + assert event.context_id == "ctx-123" + assert event.text == "Hello, world!" + + def test_create_status_event(self): + """Test creating a status event.""" + event = StreamEvent( + event_type="status", + task_id="task-456", + state=TaskState.working, + ) + + assert event.event_type == "status" + assert event.task_id == "task-456" + assert event.state == TaskState.working + + +class TestTaskResult: + """Tests for TaskResult dataclass.""" + + def test_create_task_result(self): + """Test creating a task result.""" + mock_task = Task( + id="task-123", + context_id="ctx-123", + status=TaskStatus(state=TaskState.completed), + ) + + result = TaskResult( + task=mock_task, + task_id="task-123", + state=TaskState.completed, + context_id="ctx-123", + text="Task completed successfully", + ) + + assert result.task_id == "task-123" + assert result.state == TaskState.completed + assert result.text == "Task completed successfully" + + +class TestExtractTextFromMessageParts: + """Tests for extract_text_from_message_parts function.""" + + def test_extract_from_none(self): + """Test extracting from None returns empty string.""" + result = extract_text_from_message_parts(None) + assert result == "" + + def test_extract_from_empty_list(self): + """Test extracting from empty list returns empty string.""" + result = extract_text_from_message_parts([]) + assert result == "" + + def test_extract_from_text_part_with_root(self): + """Test extracting from TextPart wrapped in Part.""" + parts = [Part(root=TextPart(text="Hello, world!"))] + result = extract_text_from_message_parts(parts) + assert result == "Hello, world!" + + def test_extract_multiple_parts(self): + """Test extracting from multiple parts joins with newlines.""" + parts = [ + Part(root=TextPart(text="First line")), + Part(root=TextPart(text="Second line")), + ] + result = extract_text_from_message_parts(parts) + assert result == "First line\nSecond line" + + +class TestTerminalStates: + """Tests for terminal state constants.""" + + def test_terminal_states_include_completed(self): + """Test that completed is a terminal state.""" + assert TaskState.completed in TERMINAL_TASK_STATES + + def test_terminal_states_include_canceled(self): + """Test that canceled is a terminal state.""" + assert TaskState.canceled in TERMINAL_TASK_STATES + + def test_terminal_states_include_failed(self): + """Test that failed is a terminal state.""" + assert TaskState.failed in TERMINAL_TASK_STATES + + def test_terminal_states_include_rejected(self): + """Test that rejected is a terminal state.""" + assert TaskState.rejected in TERMINAL_TASK_STATES + + def test_working_is_not_terminal(self): + """Test that working is not a terminal state.""" + assert TaskState.working not in TERMINAL_TASK_STATES diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 0000000..80463d5 --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,185 @@ +"""Tests for session state management.""" + +import tempfile +from pathlib import Path + +from a2a_handler.session import AgentSession, SessionStore + + +class TestAgentSession: + """Tests for AgentSession dataclass.""" + + def test_create_session_with_url_only(self): + """Test creating a session with just the URL.""" + session = AgentSession(agent_url="http://localhost:8000") + + assert session.agent_url == "http://localhost:8000" + assert session.context_id is None + assert session.task_id is None + + def test_create_session_with_all_fields(self): + """Test creating a session with all fields.""" + session = AgentSession( + agent_url="http://localhost:8000", + context_id="ctx-123", + task_id="task-456", + ) + + assert session.agent_url == "http://localhost:8000" + assert session.context_id == "ctx-123" + assert session.task_id == "task-456" + + def test_update_context_id(self): + """Test updating context_id.""" + session = AgentSession(agent_url="http://localhost:8000") + session.update(context_id="new-context") + + assert session.context_id == "new-context" + assert session.task_id is None + + def test_update_task_id(self): + """Test updating task_id.""" + session = AgentSession(agent_url="http://localhost:8000") + session.update(task_id="new-task") + + assert session.context_id is None + assert session.task_id == "new-task" + + def test_update_both_ids(self): + """Test updating both context_id and task_id.""" + session = AgentSession(agent_url="http://localhost:8000") + session.update(context_id="ctx-1", task_id="task-1") + + assert session.context_id == "ctx-1" + assert session.task_id == "task-1" + + def test_update_preserves_existing_values_when_none_passed(self): + """Test that update preserves existing values when None is passed.""" + session = AgentSession( + agent_url="http://localhost:8000", + context_id="existing-ctx", + task_id="existing-task", + ) + session.update() + + assert session.context_id == "existing-ctx" + assert session.task_id == "existing-task" + + +class TestSessionStore: + """Tests for SessionStore.""" + + def test_get_creates_new_session(self): + """Test that get creates a new session if none exists.""" + store = SessionStore() + session = store.get("http://localhost:8000") + + assert session.agent_url == "http://localhost:8000" + assert "http://localhost:8000" in store.sessions + + def test_get_returns_existing_session(self): + """Test that get returns existing session.""" + store = SessionStore() + store.sessions["http://localhost:8000"] = AgentSession( + agent_url="http://localhost:8000", + context_id="existing-ctx", + ) + + session = store.get("http://localhost:8000") + assert session.context_id == "existing-ctx" + + def test_update_creates_and_updates_session(self): + """Test that update creates and updates session.""" + with tempfile.TemporaryDirectory() as temp_directory: + store = SessionStore(session_directory=Path(temp_directory)) + session = store.update( + "http://localhost:8000", + context_id="new-ctx", + task_id="new-task", + ) + + assert session.context_id == "new-ctx" + assert session.task_id == "new-task" + + def test_clear_specific_session(self): + """Test clearing a specific session.""" + store = SessionStore() + store.sessions["http://localhost:8000"] = AgentSession( + agent_url="http://localhost:8000" + ) + store.sessions["http://localhost:9000"] = AgentSession( + agent_url="http://localhost:9000" + ) + + with tempfile.TemporaryDirectory() as temp_directory: + store.session_directory = Path(temp_directory) + store.clear("http://localhost:8000") + + assert "http://localhost:8000" not in store.sessions + assert "http://localhost:9000" in store.sessions + + def test_clear_all_sessions(self): + """Test clearing all sessions.""" + with tempfile.TemporaryDirectory() as temp_directory: + store = SessionStore(session_directory=Path(temp_directory)) + store.sessions["http://localhost:8000"] = AgentSession( + agent_url="http://localhost:8000" + ) + store.sessions["http://localhost:9000"] = AgentSession( + agent_url="http://localhost:9000" + ) + + store.clear() + + assert len(store.sessions) == 0 + + def test_list_all_sessions(self): + """Test listing all sessions.""" + store = SessionStore() + store.sessions["http://localhost:8000"] = AgentSession( + agent_url="http://localhost:8000" + ) + store.sessions["http://localhost:9000"] = AgentSession( + agent_url="http://localhost:9000" + ) + + all_sessions = store.list_all() + assert len(all_sessions) == 2 + + def test_save_and_load_sessions(self): + """Test saving and loading sessions from disk.""" + with tempfile.TemporaryDirectory() as temp_directory: + store = SessionStore(session_directory=Path(temp_directory)) + store.sessions["http://localhost:8000"] = AgentSession( + agent_url="http://localhost:8000", + context_id="ctx-123", + task_id="task-456", + ) + store.save() + + new_store = SessionStore(session_directory=Path(temp_directory)) + new_store.load() + + assert "http://localhost:8000" in new_store.sessions + loaded_session = new_store.sessions["http://localhost:8000"] + assert loaded_session.context_id == "ctx-123" + assert loaded_session.task_id == "task-456" + + def test_load_nonexistent_file(self): + """Test loading from nonexistent file does nothing.""" + with tempfile.TemporaryDirectory() as temp_directory: + store = SessionStore(session_directory=Path(temp_directory)) + store.load() + + assert len(store.sessions) == 0 + + def test_load_invalid_json(self): + """Test loading invalid JSON file handles gracefully.""" + with tempfile.TemporaryDirectory() as temp_directory: + session_file = Path(temp_directory) / "sessions.json" + session_file.write_text("not valid json {{{") + + store = SessionStore(session_directory=Path(temp_directory)) + store.load() + + assert len(store.sessions) == 0 diff --git a/tests/test_tui.py b/tests/test_tui.py index 4afae9c..cd667e3 100644 --- a/tests/test_tui.py +++ b/tests/test_tui.py @@ -1,4 +1,5 @@ import pytest + from a2a_handler.tui import HandlerTUI diff --git a/tests/test_validation.py b/tests/test_validation.py index 49736f9..a3ad6d8 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -56,7 +56,7 @@ def test_missing_required_field(self): assert result.valid is False assert len(result.issues) > 0 - field_names = [i.field for i in result.issues] + field_names = [issue.field_name for issue in result.issues] assert "name" in field_names def test_warnings_for_optional_fields(self): @@ -65,7 +65,7 @@ def test_warnings_for_optional_fields(self): result = validate_agent_card_data(data, "test", ValidationSource.FILE) assert result.valid is True - warning_fields = [w.field for w in result.warnings] + warning_fields = [warning.field_name for warning in result.warnings] assert "provider" in warning_fields assert "documentationUrl" in warning_fields assert "iconUrl" in warning_fields @@ -77,8 +77,8 @@ def test_skill_without_tags_fails_validation(self): result = validate_agent_card_data(data, "test", ValidationSource.FILE) assert result.valid is False - issue_fields = [i.field for i in result.issues] - assert any("skills" in f and "tags" in f for f in issue_fields) + issue_fields = [issue.field_name for issue in result.issues] + assert any("skills" in field and "tags" in field for field in issue_fields) def test_skill_without_examples_generates_warning(self): """Test that skills without examples generate warnings.""" @@ -86,8 +86,8 @@ def test_skill_without_examples_generates_warning(self): result = validate_agent_card_data(data, "test", ValidationSource.FILE) assert result.valid is True - warning_fields = [w.field for w in result.warnings] - assert any("examples" in f for f in warning_fields) + warning_fields = [warning.field_name for warning in result.warnings] + assert any("examples" in field for field in warning_fields) class TestValidateAgentCardFromFile: diff --git a/tests/test_webhook.py b/tests/test_webhook.py new file mode 100644 index 0000000..5f0f655 --- /dev/null +++ b/tests/test_webhook.py @@ -0,0 +1,158 @@ +"""Tests for webhook server.""" + +from datetime import datetime + +import pytest +from starlette.testclient import TestClient + +from a2a_handler.webhook import ( + PushNotification, + PushNotificationStore, + create_webhook_application, +) + + +class TestPushNotification: + """Tests for PushNotification dataclass.""" + + def test_create_notification(self): + """Test creating a notification.""" + timestamp = datetime.now() + notification = PushNotification( + timestamp=timestamp, + task_id="task-123", + payload={"status": {"state": "completed"}}, + headers={"content-type": "application/json"}, + ) + + assert notification.timestamp == timestamp + assert notification.task_id == "task-123" + assert notification.payload["status"]["state"] == "completed" + assert notification.headers["content-type"] == "application/json" + + def test_create_notification_without_task_id(self): + """Test creating a notification without task_id.""" + notification = PushNotification( + timestamp=datetime.now(), + task_id=None, + payload={}, + headers={}, + ) + + assert notification.task_id is None + + +class TestPushNotificationStore: + """Tests for PushNotificationStore.""" + + def test_add_notification(self): + """Test adding a notification to the store.""" + store = PushNotificationStore() + notification = PushNotification( + timestamp=datetime.now(), + task_id="task-123", + payload={}, + headers={}, + ) + + store.add_notification(notification) + assert len(store.notifications) == 1 + + def test_get_all_notifications(self): + """Test getting all notifications.""" + store = PushNotificationStore() + notification1 = PushNotification( + timestamp=datetime.now(), + task_id="task-1", + payload={}, + headers={}, + ) + notification2 = PushNotification( + timestamp=datetime.now(), + task_id="task-2", + payload={}, + headers={}, + ) + + store.add_notification(notification1) + store.add_notification(notification2) + + all_notifications = store.get_all_notifications() + assert len(all_notifications) == 2 + + def test_clear_all_notifications(self): + """Test clearing all notifications.""" + store = PushNotificationStore() + notification = PushNotification( + timestamp=datetime.now(), + task_id="task-123", + payload={}, + headers={}, + ) + + store.add_notification(notification) + store.clear_all_notifications() + + assert len(store.notifications) == 0 + + +class TestWebhookApplication: + """Tests for the webhook Starlette application.""" + + @pytest.fixture + def client(self): + """Create a test client for the webhook application.""" + application = create_webhook_application() + return TestClient(application) + + def test_webhook_validation_get(self, client): + """Test GET request for webhook validation.""" + response = client.get("/webhook") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert "message" in data + + def test_webhook_receive_notification(self, client): + """Test POST request to receive notification.""" + payload = { + "id": "task-123", + "status": {"state": "completed"}, + } + + response = client.post("/webhook", json=payload) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["received"] is True + + def test_webhook_invalid_json(self, client): + """Test POST request with invalid JSON.""" + response = client.post( + "/webhook", + content="not valid json", + headers={"content-type": "application/json"}, + ) + + assert response.status_code == 400 + data = response.json() + assert "error" in data + + def test_list_notifications(self, client): + """Test listing received notifications.""" + response = client.get("/notifications") + + assert response.status_code == 200 + data = response.json() + assert "count" in data + assert "notifications" in data + + def test_clear_notifications(self, client): + """Test clearing notifications.""" + response = client.post("/notifications/clear") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" From 07161bb9288c653492529289c94d566f43633c9d Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Sun, 7 Dec 2025 21:02:59 -0500 Subject: [PATCH 07/23] fix: check for TextPart root when extracting text --- src/a2a_handler/service.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/a2a_handler/service.py b/src/a2a_handler/service.py index 1bae5d7..aa102b8 100644 --- a/src/a2a_handler/service.py +++ b/src/a2a_handler/service.py @@ -96,10 +96,8 @@ def extract_text_from_message_parts(message_parts: list[Part] | None) -> str: extracted_texts = [] for part in message_parts: - if hasattr(part, "root") and hasattr(part.root, "text"): + if isinstance(part.root, TextPart): extracted_texts.append(part.root.text) - elif hasattr(part, "text"): - extracted_texts.append(part.text) return "\n".join(text for text in extracted_texts if text) From a345505689d5e7edb804aef3ceb2f32ff95dacb1 Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Mon, 8 Dec 2025 17:12:42 -0500 Subject: [PATCH 08/23] refactor(cli): imports and silence noisy logs --- pyproject.toml | 3 +++ src/a2a_handler/cli.py | 48 +++++++++++++++++++++--------------------- 2 files changed, 27 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a3305ae..95bf14a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,3 +69,6 @@ build-backend = "uv_build" [tool.uv.build-backend] module-root = "src" + +[tool.ruff.lint.per-file-ignores] +"src/a2a_handler/cli.py" = ["E402"] diff --git a/src/a2a_handler/cli.py b/src/a2a_handler/cli.py index ded8ee5..53eaa8e 100644 --- a/src/a2a_handler/cli.py +++ b/src/a2a_handler/cli.py @@ -3,9 +3,18 @@ import logging from typing import Any, Optional +# Suppress noisy third-party debug logs during import +logging.getLogger().setLevel(logging.WARNING) + import httpx import rich_click as click +from a2a.client.errors import ( + A2AClientError, + A2AClientHTTPError, + A2AClientTimeoutError, +) + from a2a_handler import __version__ from a2a_handler.common import ( console, @@ -16,6 +25,21 @@ print_panel, setup_logging, ) +from a2a_handler.server import run_server +from a2a_handler.service import A2AService, SendResult, TaskResult +from a2a_handler.session import ( + clear_session, + get_session, + get_session_store, + update_session, +) +from a2a_handler.tui import HandlerTUI +from a2a_handler.validation import ( + ValidationResult, + validate_agent_card_from_file, + validate_agent_card_from_url, +) +from a2a_handler.webhook import run_webhook_server click.rich_click.USE_RICH_MARKUP = True click.rich_click.USE_MARKDOWN = True @@ -95,30 +119,6 @@ ], } -setup_logging(level="WARNING") - -from a2a.client.errors import ( # noqa: E402 - A2AClientError, - A2AClientHTTPError, - A2AClientTimeoutError, -) - -from a2a_handler.server import run_server # noqa: E402 -from a2a_handler.service import A2AService, SendResult, TaskResult # noqa: E402 -from a2a_handler.session import ( # noqa: E402 - clear_session, - get_session, - get_session_store, - update_session, -) -from a2a_handler.tui import HandlerTUI # noqa: E402 -from a2a_handler.validation import ( # noqa: E402 - ValidationResult, - validate_agent_card_from_file, - validate_agent_card_from_url, -) -from a2a_handler.webhook import run_webhook_server # noqa: E402 - TIMEOUT = 120 From b883d4d9c948894dc8a44c07bbff862ceef94153 Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Mon, 8 Dec 2025 17:28:43 -0500 Subject: [PATCH 09/23] refactor(cli): extract formatting utils to common module --- src/a2a_handler/cli.py | 84 ++----------------- src/a2a_handler/common/__init__.py | 6 ++ src/a2a_handler/common/formatting.py | 104 ++++++++++++++++++++++++ tests/test_formatting.py | 116 +++++++++++++++++++++++++++ 4 files changed, 231 insertions(+), 79 deletions(-) create mode 100644 src/a2a_handler/common/formatting.py create mode 100644 tests/test_formatting.py diff --git a/src/a2a_handler/cli.py b/src/a2a_handler/cli.py index 53eaa8e..28b77dc 100644 --- a/src/a2a_handler/cli.py +++ b/src/a2a_handler/cli.py @@ -1,7 +1,7 @@ import asyncio import json import logging -from typing import Any, Optional +from typing import Optional # Suppress noisy third-party debug logs during import logging.getLogger().setLevel(logging.WARNING) @@ -18,6 +18,8 @@ from a2a_handler import __version__ from a2a_handler.common import ( console, + format_field_name, + format_value, get_logger, print_error, print_json, @@ -180,82 +182,6 @@ def cli(ctx, verbose: bool, debug: bool) -> None: setup_logging(level="WARNING") -def _format_field_name(name: str) -> str: - """Convert snake_case or camelCase to Title Case.""" - import re - - name = re.sub(r"([a-z])([A-Z])", r"\1 \2", name) - name = name.replace("_", " ") - return name.title() - - -def _format_value(value: Any, indent: int = 0) -> str: - """Recursively format a value for display, returning only truthy content.""" - prefix = " " * indent - - if value is None or value == "" or value == [] or value == {}: - return "" - - if isinstance(value, bool): - return "✓" if value else "✗" - - if isinstance(value, str): - return value - - if isinstance(value, int | float): - return str(value) - - if isinstance(value, list): - lines: list[str] = [] - for item in value: - if hasattr(item, "model_dump"): - item_dict: dict[str, Any] = item.model_dump() - name = item_dict.get("name") or item_dict.get("id") or "Item" - desc = item_dict.get("description") or "" - if desc: - desc_prefix = " " * (indent + 1) - lines.append(f"{prefix} • [cyan]{name}[/cyan]") - lines.append(f"{desc_prefix} {desc}") - else: - lines.append(f"{prefix} • [cyan]{name}[/cyan]") - elif isinstance(item, dict): - item_d: dict[str, Any] = item - name = item_d.get("name") or item_d.get("id") or "Item" - desc = item_d.get("description") or "" - if desc: - desc_prefix = " " * (indent + 1) - lines.append(f"{prefix} • [cyan]{name}[/cyan]") - lines.append(f"{desc_prefix} {desc}") - else: - lines.append(f"{prefix} • [cyan]{name}[/cyan]") - else: - formatted = _format_value(item, indent) - if formatted: - lines.append(f"{prefix} • {formatted}") - return "\n" + "\n".join(lines) if lines else "" - - if hasattr(value, "model_dump"): - value = value.model_dump() - - if isinstance(value, dict): - dict_lines: list[str] = [] - for k, v in value.items(): - if isinstance(k, str) and k.startswith("_"): - continue - formatted = _format_value(v, indent + 1) - if formatted: - field_name = _format_field_name(str(k)) - if "\n" in formatted: - dict_lines.append( - f"{prefix}[bold]{field_name}:[/bold]\n{formatted}" - ) - else: - dict_lines.append(f"{prefix}[bold]{field_name}:[/bold] {formatted}") - return "\n".join(dict_lines) if dict_lines else "" - - return str(value) if value else "" - - def _format_send_result(result: SendResult, output: str) -> None: """Format and display a send result.""" if output == "json": @@ -351,9 +277,9 @@ async def fetch() -> None: for key, value in card_dict.items(): if key.startswith("_"): continue - formatted = _format_value(value) + formatted = format_value(value) if formatted: - field_name = _format_field_name(key) + field_name = format_field_name(key) if "\n" in formatted: content_parts.append( f"[bold]{field_name}:[/bold]\n{formatted}" diff --git a/src/a2a_handler/common/__init__.py b/src/a2a_handler/common/__init__.py index ab089ac..1d587f8 100644 --- a/src/a2a_handler/common/__init__.py +++ b/src/a2a_handler/common/__init__.py @@ -1,5 +1,9 @@ """Common utilities for Handler.""" +from .formatting import ( + format_field_name, + format_value, +) from .logging import ( HANDLER_THEME, LogLevel, @@ -23,6 +27,8 @@ "HANDLER_THEME", "LogLevel", "console", + "format_field_name", + "format_value", "get_logger", "print_error", "print_info", diff --git a/src/a2a_handler/common/formatting.py b/src/a2a_handler/common/formatting.py new file mode 100644 index 0000000..11ac563 --- /dev/null +++ b/src/a2a_handler/common/formatting.py @@ -0,0 +1,104 @@ +"""Formatting utilities for Handler.""" + +import re +from typing import Any + + +def format_field_name(name: str) -> str: + """Convert snake_case or camelCase to Title Case. + + Args: + name: The field name to format. + + Returns: + The formatted field name in Title Case. + + Examples: + >>> format_field_name("snake_case") + 'Snake Case' + >>> format_field_name("camelCase") + 'Camel Case' + >>> format_field_name("already Title") + 'Already Title' + """ + name = re.sub(r"([a-z])([A-Z])", r"\1 \2", name) + name = name.replace("_", " ") + return name.title() + + +def format_value(value: Any, indent: int = 0) -> str: + """Recursively format a value for display, returning only truthy content. + + Handles None, empty strings, empty lists/dicts, bools, strings, numbers, + lists (including Pydantic models), dicts, and objects with model_dump(). + + Args: + value: The value to format. + indent: The current indentation level. + + Returns: + A formatted string representation, or empty string for falsy values. + """ + prefix = " " * indent + + if value is None or value == "" or value == [] or value == {}: + return "" + + if isinstance(value, bool): + return "✓" if value else "✗" + + if isinstance(value, str): + return value + + if isinstance(value, int | float): + return str(value) + + if isinstance(value, list): + lines: list[str] = [] + for item in value: + if hasattr(item, "model_dump"): + item_dict: dict[str, Any] = item.model_dump() + name = item_dict.get("name") or item_dict.get("id") or "Item" + desc = item_dict.get("description") or "" + if desc: + desc_prefix = " " * (indent + 1) + lines.append(f"{prefix} • [cyan]{name}[/cyan]") + lines.append(f"{desc_prefix} {desc}") + else: + lines.append(f"{prefix} • [cyan]{name}[/cyan]") + elif isinstance(item, dict): + item_d: dict[str, Any] = item + name = item_d.get("name") or item_d.get("id") or "Item" + desc = item_d.get("description") or "" + if desc: + desc_prefix = " " * (indent + 1) + lines.append(f"{prefix} • [cyan]{name}[/cyan]") + lines.append(f"{desc_prefix} {desc}") + else: + lines.append(f"{prefix} • [cyan]{name}[/cyan]") + else: + formatted = format_value(item, indent) + if formatted: + lines.append(f"{prefix} • {formatted}") + return "\n" + "\n".join(lines) if lines else "" + + if hasattr(value, "model_dump"): + value = value.model_dump() + + if isinstance(value, dict): + dict_lines: list[str] = [] + for k, v in value.items(): + if isinstance(k, str) and k.startswith("_"): + continue + formatted = format_value(v, indent + 1) + if formatted: + field_name = format_field_name(str(k)) + if "\n" in formatted: + dict_lines.append( + f"{prefix}[bold]{field_name}:[/bold]\n{formatted}" + ) + else: + dict_lines.append(f"{prefix}[bold]{field_name}:[/bold] {formatted}") + return "\n".join(dict_lines) if dict_lines else "" + + return str(value) if value else "" diff --git a/tests/test_formatting.py b/tests/test_formatting.py new file mode 100644 index 0000000..b3825c0 --- /dev/null +++ b/tests/test_formatting.py @@ -0,0 +1,116 @@ +"""Tests for formatting utilities.""" + +from a2a_handler.common.formatting import format_field_name, format_value + + +class TestFormatFieldName: + """Tests for format_field_name function.""" + + def test_snake_case(self): + """Test conversion of snake_case to Title Case.""" + assert format_field_name("snake_case") == "Snake Case" + assert format_field_name("some_field_name") == "Some Field Name" + + def test_camel_case(self): + """Test conversion of camelCase to Title Case.""" + assert format_field_name("camelCase") == "Camel Case" + assert format_field_name("someFieldName") == "Some Field Name" + + def test_already_spaced(self): + """Test handling of already spaced strings.""" + assert format_field_name("already spaced") == "Already Spaced" + + def test_single_word(self): + """Test single word remains capitalized.""" + assert format_field_name("name") == "Name" + assert format_field_name("id") == "Id" + + def test_mixed_case_with_underscores(self): + """Test mixed camelCase with underscores.""" + assert format_field_name("some_camelCase_field") == "Some Camel Case Field" + + +class TestFormatValue: + """Tests for format_value function.""" + + def test_none_returns_empty(self): + """Test None returns empty string.""" + assert format_value(None) == "" + + def test_empty_string_returns_empty(self): + """Test empty string returns empty string.""" + assert format_value("") == "" + + def test_empty_list_returns_empty(self): + """Test empty list returns empty string.""" + assert format_value([]) == "" + + def test_empty_dict_returns_empty(self): + """Test empty dict returns empty string.""" + assert format_value({}) == "" + + def test_bool_true(self): + """Test True returns checkmark.""" + assert format_value(True) == "✓" + + def test_bool_false(self): + """Test False returns X.""" + assert format_value(False) == "✗" + + def test_string(self): + """Test string returns itself.""" + assert format_value("hello") == "hello" + assert format_value("hello world") == "hello world" + + def test_integer(self): + """Test integer returns string representation.""" + assert format_value(42) == "42" + assert format_value(0) == "0" + + def test_float(self): + """Test float returns string representation.""" + assert format_value(3.14) == "3.14" + + def test_simple_dict(self): + """Test simple dict formatting.""" + result = format_value({"name": "test"}) + assert "Name:" in result + assert "test" in result + + def test_dict_skips_private_keys(self): + """Test dict skips keys starting with underscore.""" + result = format_value({"name": "test", "_private": "hidden"}) + assert "Name:" in result + assert "_private" not in result + assert "hidden" not in result + + def test_list_of_strings(self): + """Test list of strings formatting.""" + result = format_value(["a", "b", "c"]) + assert "• a" in result + assert "• b" in result + assert "• c" in result + + def test_list_of_dicts_with_name(self): + """Test list of dicts uses name field.""" + result = format_value([{"name": "Item1"}, {"name": "Item2"}]) + assert "Item1" in result + assert "Item2" in result + + def test_list_of_dicts_with_description(self): + """Test list of dicts includes description.""" + result = format_value([{"name": "Item1", "description": "Description1"}]) + assert "Item1" in result + assert "Description1" in result + + def test_nested_dict(self): + """Test nested dict formatting.""" + result = format_value({"outer": {"inner": "value"}}) + assert "Outer:" in result + assert "Inner:" in result + assert "value" in result + + def test_indentation(self): + """Test indentation is applied.""" + result = format_value({"name": "test"}, indent=1) + assert result.startswith(" ") From 9e3c485b9469cdca13bf83b7c980eb307a6a3819 Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Mon, 8 Dec 2025 18:17:01 -0500 Subject: [PATCH 10/23] chore: remove PyPI badge due to rate limiting --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index e89bec9..262faf1 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,6 @@ [![A2A Protocol](https://img.shields.io/badge/A2A_Protocol-v0.3.0-blue)](https://a2a-protocol.org/latest/) [![PyPI version](https://img.shields.io/pypi/v/a2a-handler)](https://pypi.org/project/a2a-handler/) [![PyPI - Status](https://img.shields.io/pypi/status/a2a-handler)](https://pypi.org/project/a2a-handler/) -[![PyPI downloads](https://img.shields.io/pypi/dm/a2a-handler)](https://pypi.org/project/a2a-handler/) [![GitHub stars](https://img.shields.io/github/stars/alDuncanson/handler)](https://github.com/alDuncanson/handler/stargazers) An [A2A](https://a2a-protocol.org/latest/) Protocol client TUI and CLI. From c0f0154334bf2a91902b6416b867d60a683eb363 Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Mon, 8 Dec 2025 18:24:10 -0500 Subject: [PATCH 11/23] refactor(cli): agent card validation --- src/a2a_handler/cli.py | 23 +--- src/a2a_handler/validation.py | 242 +++++++++------------------------- tests/test_validation.py | 99 +++++++------- 3 files changed, 118 insertions(+), 246 deletions(-) diff --git a/src/a2a_handler/cli.py b/src/a2a_handler/cli.py index 28b77dc..3dd54f7 100644 --- a/src/a2a_handler/cli.py +++ b/src/a2a_handler/cli.py @@ -179,7 +179,7 @@ def cli(ctx, verbose: bool, debug: bool) -> None: log.debug("Verbose logging enabled") setup_logging(level="INFO") else: - setup_logging(level="WARNING") + setup_logging(level="ERROR") def _format_send_result(result: SendResult, output: str) -> None: @@ -315,14 +315,6 @@ def _format_validation_result(result: ValidationResult, output: str) -> None: } for issue in result.issues ], - "warnings": [ - { - "field": warning.field_name, - "message": warning.message, - "type": warning.issue_type, - } - for warning in result.warnings - ], } print_json(json.dumps(output_data, indent=2)) return @@ -334,17 +326,6 @@ def _format_validation_result(result: ValidationResult, output: str) -> None: f"[bold]Protocol Version:[/bold] {result.protocol_version}", f"[bold]Source:[/bold] {result.source}", ] - - if result.warnings: - content_parts.append("") - content_parts.append( - f"[bold yellow]Warnings ({len(result.warnings)}):[/bold yellow]" - ) - for warning in result.warnings: - content_parts.append( - f" [yellow]⚠[/yellow] {warning.field_name}: {warning.message}" - ) - print_panel("\n".join(content_parts), title=title) else: title = "[bold red]✗ Invalid Agent Card[/bold red]" @@ -394,7 +375,7 @@ async def do_validate() -> None: _format_validation_result(result, output) if not result.valid: - raise click.Abort() + raise SystemExit(1) asyncio.run(do_validate()) diff --git a/src/a2a_handler/validation.py b/src/a2a_handler/validation.py index 6d876c0..bc36bdc 100644 --- a/src/a2a_handler/validation.py +++ b/src/a2a_handler/validation.py @@ -7,6 +7,7 @@ from typing import Any import httpx +from a2a.client import A2ACardResolver from a2a.types import AgentCard from pydantic import ValidationError @@ -30,9 +31,6 @@ class ValidationIssue: message: str issue_type: str = "error" - def __str__(self) -> str: - return f"[{self.issue_type}] {self.field_name}: {self.message}" - @dataclass class ValidationResult: @@ -43,7 +41,6 @@ class ValidationResult: source_type: ValidationSource agent_card: AgentCard | None = None issues: list[ValidationIssue] = field(default_factory=list) - warnings: list[ValidationIssue] = field(default_factory=list) raw_data: dict[str, Any] | None = None @property @@ -59,160 +56,36 @@ def agent_name(self) -> str: def protocol_version(self) -> str: """Get the protocol version if available.""" if self.agent_card: - return self.agent_card.protocol_version or "1.0" + return self.agent_card.protocol_version or "Unknown" if self.raw_data: - return self.raw_data.get("protocolVersion", "1.0") + return self.raw_data.get("protocolVersion", "Unknown") return "Unknown" -def parse_pydantic_validation_errors( - validation_error: ValidationError, -) -> list[ValidationIssue]: +def _parse_validation_errors(error: ValidationError) -> list[ValidationIssue]: """Parse Pydantic validation errors into ValidationIssues.""" - validation_issues = [] - for error_detail in validation_error.errors(): - field_path = ".".join(str(location) for location in error_detail["loc"]) - error_message = error_detail["msg"] - error_type = error_detail["type"] - validation_issues.append( + issues = [] + for detail in error.errors(): + field_path = ".".join(str(loc) for loc in detail["loc"]) + issues.append( ValidationIssue( field_name=field_path or "root", - message=error_message, - issue_type=error_type, - ) - ) - return validation_issues - - -def check_agent_card_best_practices(agent_card: AgentCard) -> list[ValidationIssue]: - """Check for best practices and generate warnings. - - Note: In A2A v0.3.0, the following are REQUIRED fields and validated by Pydantic: - - name, description, url, version - - capabilities, defaultInputModes, defaultOutputModes, skills - - preferredTransport (defaults to JSONRPC in SDK) - - This function only warns about optional fields that improve agent discoverability. - """ - best_practice_warnings = [] - - if not agent_card.provider: - best_practice_warnings.append( - ValidationIssue( - field_name="provider", - message="Agent card should specify a provider for better discoverability", - issue_type="warning", - ) - ) - - if not agent_card.documentation_url: - best_practice_warnings.append( - ValidationIssue( - field_name="documentationUrl", - message="Agent card should include documentation URL", - issue_type="warning", - ) - ) - - if not agent_card.icon_url: - best_practice_warnings.append( - ValidationIssue( - field_name="iconUrl", - message="Agent card should include an icon URL for UI display", - issue_type="warning", - ) - ) - - if agent_card.skills: - for skill_index, skill in enumerate(agent_card.skills): - if not skill.description: - best_practice_warnings.append( - ValidationIssue( - field_name=f"skills[{skill_index}].description", - message=f"Skill '{skill.name}' should have a description", - issue_type="warning", - ) - ) - if not skill.examples or len(skill.examples) == 0: - best_practice_warnings.append( - ValidationIssue( - field_name=f"skills[{skill_index}].examples", - message=f"Skill '{skill.name}' should include example prompts", - issue_type="warning", - ) - ) - - if ( - not agent_card.additional_interfaces - or len(agent_card.additional_interfaces) == 0 - ): - best_practice_warnings.append( - ValidationIssue( - field_name="additionalInterfaces", - message="Consider declaring additional transport interfaces for flexibility", - issue_type="warning", + message=detail["msg"], + issue_type=detail["type"], ) ) - - return best_practice_warnings - - -def validate_agent_card_data( - card_data: dict[str, Any], - source: str, - source_type: ValidationSource, -) -> ValidationResult: - """Validate agent card data against the A2A protocol schema. - - Args: - card_data: Raw agent card data as a dictionary - source: The source (URL or file path) of the data - source_type: Whether the source is a URL or file - - Returns: - ValidationResult with validation status and any issues - """ - logger.debug("Validating agent card data from %s", source) - - try: - agent_card = AgentCard.model_validate(card_data) - logger.info("Agent card validation successful for %s", agent_card.name) - - best_practice_warnings = check_agent_card_best_practices(agent_card) - - return ValidationResult( - valid=True, - source=source, - source_type=source_type, - agent_card=agent_card, - warnings=best_practice_warnings, - raw_data=card_data, - ) - - except ValidationError as validation_error: - logger.warning("Agent card validation failed: %s", validation_error) - validation_issues = parse_pydantic_validation_errors(validation_error) - - return ValidationResult( - valid=False, - source=source, - source_type=source_type, - issues=validation_issues, - raw_data=card_data, - ) + return issues async def validate_agent_card_from_url( agent_url: str, http_client: httpx.AsyncClient | None = None, - agent_card_path: str | None = None, ) -> ValidationResult: - """Fetch and validate an agent card from a URL. + """Fetch and validate an agent card from a URL using the A2A SDK. Args: agent_url: The base URL of the agent http_client: Optional HTTP client to use - agent_card_path: Optional custom path to the agent card (default: /.well-known/agent.json) Returns: ValidationResult with validation status and any issues @@ -224,60 +97,52 @@ async def validate_agent_card_from_url( http_client = httpx.AsyncClient(timeout=30) try: - base_url = agent_url.rstrip("/") - if agent_card_path: - full_url = f"{base_url}/{agent_card_path.lstrip('/')}" - else: - full_url = f"{base_url}/.well-known/agent-card.json" + resolver = A2ACardResolver(http_client, agent_url) + agent_card = await resolver.get_agent_card() - logger.debug("Fetching agent card from %s", full_url) - response = await http_client.get(full_url) - response.raise_for_status() - - card_data = response.json() - return validate_agent_card_data(card_data, agent_url, ValidationSource.URL) + logger.info("Agent card validation successful for %s", agent_card.name) + return ValidationResult( + valid=True, + source=agent_url, + source_type=ValidationSource.URL, + agent_card=agent_card, + ) - except httpx.HTTPStatusError as http_error: - logger.error("HTTP error fetching agent card: %s", http_error) + except ValidationError as e: + logger.warning("Agent card validation failed: %s", e) return ValidationResult( valid=False, source=agent_url, source_type=ValidationSource.URL, - issues=[ - ValidationIssue( - field_name="http", - message=f"HTTP {http_error.response.status_code}: {http_error.response.text[:200]}", - issue_type="http_error", - ) - ], + issues=_parse_validation_errors(e), ) - except httpx.RequestError as request_error: - logger.error("Request error fetching agent card: %s", request_error) + except httpx.HTTPStatusError as e: + logger.error("HTTP error fetching agent card: %s", e) return ValidationResult( valid=False, source=agent_url, source_type=ValidationSource.URL, issues=[ ValidationIssue( - field_name="connection", - message=str(request_error), - issue_type="connection_error", + field_name="http", + message=f"HTTP {e.response.status_code}: {e.response.text[:200]}", + issue_type="http_error", ) ], ) - except json.JSONDecodeError as json_error: - logger.error("JSON decode error: %s", json_error) + except httpx.RequestError as e: + logger.error("Request error fetching agent card: %s", e) return ValidationResult( valid=False, source=agent_url, source_type=ValidationSource.URL, issues=[ ValidationIssue( - field_name="json", - message=f"Invalid JSON: {json_error}", - issue_type="json_error", + field_name="connection", + message=str(e), + issue_type="connection_error", ) ], ) @@ -329,14 +194,35 @@ def validate_agent_card_from_file(file_path: str | Path) -> ValidationResult: ], ) + card_data: dict[str, Any] | None = None + try: - with open(path, encoding="utf-8") as card_file: - card_data = json.load(card_file) + with open(path, encoding="utf-8") as f: + card_data = json.load(f) - return validate_agent_card_data(card_data, str(path), ValidationSource.FILE) + agent_card = AgentCard.model_validate(card_data) + logger.info("Agent card validation successful for %s", agent_card.name) + + return ValidationResult( + valid=True, + source=str(path), + source_type=ValidationSource.FILE, + agent_card=agent_card, + raw_data=card_data, + ) + + except ValidationError as e: + logger.warning("Agent card validation failed: %s", e) + return ValidationResult( + valid=False, + source=str(path), + source_type=ValidationSource.FILE, + issues=_parse_validation_errors(e), + raw_data=card_data, + ) - except json.JSONDecodeError as json_error: - logger.error("JSON decode error: %s", json_error) + except json.JSONDecodeError as e: + logger.error("JSON decode error: %s", e) return ValidationResult( valid=False, source=str(path), @@ -344,7 +230,7 @@ def validate_agent_card_from_file(file_path: str | Path) -> ValidationResult: issues=[ ValidationIssue( field_name="json", - message=f"Invalid JSON at line {json_error.lineno}, column {json_error.colno}: {json_error.msg}", + message=f"Invalid JSON at line {e.lineno}, column {e.colno}: {e.msg}", issue_type="json_error", ) ], @@ -365,8 +251,8 @@ def validate_agent_card_from_file(file_path: str | Path) -> ValidationResult: ], ) - except OSError as os_error: - logger.error("Error reading file: %s", os_error) + except OSError as e: + logger.error("Error reading file: %s", e) return ValidationResult( valid=False, source=str(path), @@ -374,7 +260,7 @@ def validate_agent_card_from_file(file_path: str | Path) -> ValidationResult: issues=[ ValidationIssue( field_name="file", - message=str(os_error), + message=str(e), issue_type="file_error", ) ], diff --git a/tests/test_validation.py b/tests/test_validation.py index a3ad6d8..275fa53 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -4,9 +4,10 @@ import tempfile from pathlib import Path +from a2a.types import AgentCard + from a2a_handler.validation import ( ValidationSource, - validate_agent_card_data, validate_agent_card_from_file, ) @@ -36,58 +37,38 @@ def _minimal_valid_agent_card() -> dict: } -class TestValidateAgentCardData: - """Tests for validate_agent_card_data function.""" +class TestAgentCardValidation: + """Tests for agent card validation using the A2A SDK.""" def test_valid_minimal_card(self): """Test validation of a minimal valid agent card.""" data = _minimal_valid_agent_card() - result = validate_agent_card_data(data, "test", ValidationSource.FILE) + card = AgentCard.model_validate(data) - assert result.valid is True - assert result.agent_card is not None - assert result.agent_card.name == "Test Agent" - assert len(result.issues) == 0 + assert card.name == "Test Agent" + assert card.description == "A test agent" + assert len(card.skills) == 1 def test_missing_required_field(self): """Test validation fails when required field is missing.""" data = {"url": "http://localhost:8000"} - result = validate_agent_card_data(data, "test", ValidationSource.FILE) - - assert result.valid is False - assert len(result.issues) > 0 - field_names = [issue.field_name for issue in result.issues] - assert "name" in field_names - - def test_warnings_for_optional_fields(self): - """Test warnings are generated for missing optional fields.""" - data = _minimal_valid_agent_card() - result = validate_agent_card_data(data, "test", ValidationSource.FILE) - assert result.valid is True - warning_fields = [warning.field_name for warning in result.warnings] - assert "provider" in warning_fields - assert "documentationUrl" in warning_fields - assert "iconUrl" in warning_fields + try: + AgentCard.model_validate(data) + assert False, "Expected validation to fail" + except Exception: + pass def test_skill_without_tags_fails_validation(self): """Test that skills without tags fail validation (tags are required in v0.3.0).""" data = _minimal_valid_agent_card() data["skills"] = [{"id": "test", "name": "Test", "description": "Test desc"}] - result = validate_agent_card_data(data, "test", ValidationSource.FILE) - assert result.valid is False - issue_fields = [issue.field_name for issue in result.issues] - assert any("skills" in field and "tags" in field for field in issue_fields) - - def test_skill_without_examples_generates_warning(self): - """Test that skills without examples generate warnings.""" - data = _minimal_valid_agent_card() - result = validate_agent_card_data(data, "test", ValidationSource.FILE) - - assert result.valid is True - warning_fields = [warning.field_name for warning in result.warnings] - assert any("examples" in field for field in warning_fields) + try: + AgentCard.model_validate(data) + assert False, "Expected validation to fail" + except Exception: + pass class TestValidateAgentCardFromFile: @@ -147,30 +128,54 @@ class TestValidationResult: def test_agent_name_from_card(self): """Test agent_name property returns name from agent card.""" data = _minimal_valid_agent_card() - result = validate_agent_card_data(data, "test", ValidationSource.FILE) - assert result.agent_name == "Test Agent" + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(data, f) + f.flush() + + result = validate_agent_card_from_file(f.name) + assert result.agent_name == "Test Agent" + + Path(f.name).unlink() def test_agent_name_from_raw_data(self): """Test agent_name property returns name from raw data when card is None.""" data = {"name": "Raw Agent", "url": "invalid"} - result = validate_agent_card_data(data, "test", ValidationSource.FILE) - assert result.valid is False - assert result.agent_name == "Raw Agent" + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(data, f) + f.flush() + + result = validate_agent_card_from_file(f.name) + assert result.valid is False + assert result.agent_name == "Raw Agent" + + Path(f.name).unlink() def test_protocol_version_from_sdk(self): """Test protocol_version returns the SDK default version.""" data = _minimal_valid_agent_card() - result = validate_agent_card_data(data, "test", ValidationSource.FILE) - assert result.protocol_version is not None - assert len(result.protocol_version) > 0 + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(data, f) + f.flush() + + result = validate_agent_card_from_file(f.name) + assert result.protocol_version is not None + assert len(result.protocol_version) > 0 + + Path(f.name).unlink() def test_protocol_version_explicit(self): """Test protocol_version returns explicit version when set.""" data = _minimal_valid_agent_card() data["protocolVersion"] = "2.0" - result = validate_agent_card_data(data, "test", ValidationSource.FILE) - assert result.protocol_version == "2.0" + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(data, f) + f.flush() + + result = validate_agent_card_from_file(f.name) + assert result.protocol_version == "2.0" + + Path(f.name).unlink() From 3876e28000745b4129f324a2e913f40204687ca0 Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Mon, 8 Dec 2025 20:05:00 -0500 Subject: [PATCH 12/23] refactor(cli): move push commands under tasks group --- src/a2a_handler/cli.py | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/src/a2a_handler/cli.py b/src/a2a_handler/cli.py index 3dd54f7..b4307c4 100644 --- a/src/a2a_handler/cli.py +++ b/src/a2a_handler/cli.py @@ -90,7 +90,7 @@ "handler": [ { "name": "Agent Commands", - "commands": ["card", "send", "validate", "tasks", "push"], + "commands": ["card", "send", "validate", "tasks"], }, { "name": "Interface Commands", @@ -106,11 +106,9 @@ "name": "Task Commands", "commands": ["get", "cancel", "resubscribe"], }, - ], - "handler push": [ { "name": "Push Notification Commands", - "commands": ["set", "get"], + "commands": ["push-set", "push-get"], }, ], "handler session": [ @@ -653,13 +651,7 @@ async def resubscribe() -> None: asyncio.run(resubscribe()) -@cli.group() -def push() -> None: - """Manage push notification configurations.""" - pass - - -@push.command("set") +@tasks.command("push-set") @click.argument("agent_url") @click.argument("task_id") @click.option("--url", "-u", required=True, help="Webhook URL to receive notifications") @@ -671,7 +663,7 @@ def push() -> None: default="text", help="Output format", ) -def push_set( +def tasks_push_set( agent_url: str, task_id: str, url: str, @@ -684,7 +676,7 @@ def push_set( when task status changes. Example: - handler push set http://localhost:8000 TASK_ID --url http://localhost:9000/webhook + handler tasks push-set http://localhost:8000 TASK_ID --url http://localhost:9000/webhook """ log.info("Setting push config for task %s at %s", task_id, agent_url) @@ -718,7 +710,7 @@ async def set_push() -> None: asyncio.run(set_push()) -@push.command("get") +@tasks.command("push-get") @click.argument("agent_url") @click.argument("task_id") @click.argument("config_id") @@ -729,7 +721,7 @@ async def set_push() -> None: default="text", help="Output format", ) -def push_get( +def tasks_push_get( agent_url: str, task_id: str, config_id: str, From cd521370a5fb64366867177f44135e2fb9098071 Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Mon, 8 Dec 2025 20:43:24 -0500 Subject: [PATCH 13/23] feat: print push notification info from config --- src/a2a_handler/cli.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/a2a_handler/cli.py b/src/a2a_handler/cli.py index b4307c4..0d1b31d 100644 --- a/src/a2a_handler/cli.py +++ b/src/a2a_handler/cli.py @@ -698,10 +698,14 @@ async def set_push() -> None: console.print( "[green]Push notification config set successfully[/green]" ) - console.print(f"[bold]Task ID:[/bold] {task_id}") - console.print(f"[bold]Webhook URL:[/bold] {url}") - if token: - console.print(f"[bold]Token:[/bold] {token[:20]}...") + console.print(f"[bold]Task ID:[/bold] {config.task_id}") + if config.push_notification_config: + pnc = config.push_notification_config + console.print(f"[bold]URL:[/bold] {pnc.url}") + if pnc.token: + console.print(f"[bold]Token:[/bold] {pnc.token[:20]}...") + if pnc.id: + console.print(f"[bold]Config ID:[/bold] {pnc.id}") except Exception as e: _handle_client_error(e, agent_url) From 24285e51d34a048b2243689c13abf3077db34421 Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Mon, 8 Dec 2025 23:18:52 -0500 Subject: [PATCH 14/23] refactor(cli): major rewrite of command structure and output formatting --- src/a2a_handler/cli.py | 1447 ++++++++++++++++++---------- src/a2a_handler/common/__init__.py | 24 +- src/a2a_handler/common/output.py | 177 ++++ src/a2a_handler/common/printing.py | 79 -- src/a2a_handler/service.py | 79 +- 5 files changed, 1176 insertions(+), 630 deletions(-) create mode 100644 src/a2a_handler/common/output.py delete mode 100644 src/a2a_handler/common/printing.py diff --git a/src/a2a_handler/cli.py b/src/a2a_handler/cli.py index 0d1b31d..0bf2d76 100644 --- a/src/a2a_handler/cli.py +++ b/src/a2a_handler/cli.py @@ -1,14 +1,22 @@ +"""Handler CLI - A2A protocol client. + +Command structure based on A2A protocol method mapping: +- message send/stream: Send messages to agents +- task get/cancel/resubscribe: Manage tasks +- task notification set/get/list/delete: Push notification configs +- card get/validate: Agent card operations +- server agent/push: Run local servers +- session list/show/get/clear: Manage saved sessions +""" + import asyncio -import json import logging from typing import Optional -# Suppress noisy third-party debug logs during import logging.getLogger().setLevel(logging.WARNING) import httpx import rich_click as click - from a2a.client.errors import ( A2AClientError, A2AClientHTTPError, @@ -17,14 +25,10 @@ from a2a_handler import __version__ from a2a_handler.common import ( - console, format_field_name, format_value, get_logger, - print_error, - print_json, - print_markdown, - print_panel, + get_output_context, setup_logging, ) from a2a_handler.server import run_server @@ -43,6 +47,7 @@ ) from a2a_handler.webhook import run_webhook_server +# rich_click configuration click.rich_click.USE_RICH_MARKUP = True click.rich_click.USE_MARKDOWN = True click.rich_click.SHOW_ARGUMENTS = True @@ -52,14 +57,13 @@ click.rich_click.STYLE_ARGUMENT = "cyan" click.rich_click.STYLE_COMMAND = "green" click.rich_click.STYLE_SWITCH = "bold green" + click.rich_click.OPTION_GROUPS = { "handler": [ - { - "name": "Global Options", - "options": ["--verbose", "--debug", "--help"], - }, + {"name": "Global Options", "options": ["--verbose", "--debug", "--help"]}, + {"name": "Output Options", "options": ["--raw"]}, ], - "handler send": [ + "handler message send": [ { "name": "Message Options", "options": ["--stream", "--continue", "--context-id", "--task-id"], @@ -68,58 +72,73 @@ "name": "Push Notification Options", "options": ["--push-url", "--push-token"], }, - { - "name": "Output Options", - "options": ["--output", "--help"], - }, + {"name": "Output Options", "options": ["--output", "--help"]}, ], - "handler server": [ + "handler message stream": [ { - "name": "Server Options", - "options": ["--host", "--port", "--help"], + "name": "Conversation Options", + "options": ["--continue", "--context-id", "--task-id"], }, - ], - "handler tasks get": [ { - "name": "Query Options", - "options": ["--history-length", "--output", "--help"], + "name": "Push Notification Options", + "options": ["--push-url", "--push-token"], }, + {"name": "Output Options", "options": ["--output", "--help"]}, + ], + "handler task get": [ + {"name": "Query Options", "options": ["--history-length"]}, + {"name": "Output Options", "options": ["--output", "--help"]}, + ], + "handler task notification set": [ + {"name": "Notification Options", "options": ["--url", "--token"]}, + {"name": "Output Options", "options": ["--output", "--help"]}, + ], + "handler card get": [ + {"name": "Card Options", "options": ["--authenticated"]}, + {"name": "Output Options", "options": ["--output", "--help"]}, + ], + "handler server agent": [ + {"name": "Server Options", "options": ["--host", "--port", "--help"]}, + ], + "handler server push": [ + {"name": "Server Options", "options": ["--host", "--port", "--help"]}, + ], + "handler session clear": [ + {"name": "Clear Options", "options": ["--all", "--help"]}, ], } + click.rich_click.COMMAND_GROUPS = { "handler": [ - { - "name": "Agent Commands", - "commands": ["card", "send", "validate", "tasks"], - }, - { - "name": "Interface Commands", - "commands": ["tui", "server", "webhook"], - }, - { - "name": "Utility Commands", - "commands": ["version", "session"], - }, + {"name": "Agent Communication", "commands": ["message", "task"]}, + {"name": "Agent Discovery", "commands": ["card"]}, + {"name": "Interfaces", "commands": ["tui", "server"]}, + {"name": "Utilities", "commands": ["session", "version"]}, ], - "handler tasks": [ - { - "name": "Task Commands", - "commands": ["get", "cancel", "resubscribe"], - }, - { - "name": "Push Notification Commands", - "commands": ["push-set", "push-get"], - }, + "handler message": [ + {"name": "Message Commands", "commands": ["send", "stream"]}, + ], + "handler task": [ + {"name": "Task Commands", "commands": ["get", "cancel", "resubscribe"]}, + {"name": "Push Notifications", "commands": ["notification"]}, + ], + "handler task notification": [ + {"name": "Notification Commands", "commands": ["set", "get", "list", "delete"]}, + ], + "handler card": [ + {"name": "Card Commands", "commands": ["get", "validate"]}, + ], + "handler server": [ + {"name": "Server Commands", "commands": ["agent", "push"]}, ], "handler session": [ - { - "name": "Session Commands", - "commands": ["list", "show", "clear"], - }, + {"name": "Session Commands", "commands": ["list", "show", "get", "clear"]}, ], } + TIMEOUT = 120 +log = get_logger(__name__) def build_http_client(timeout: int = TIMEOUT) -> httpx.AsyncClient: @@ -127,115 +146,182 @@ def build_http_client(timeout: int = TIMEOUT) -> httpx.AsyncClient: return httpx.AsyncClient(timeout=timeout) -log = get_logger(__name__) +def _handle_client_error(e: Exception, agent_url: str, ctx: object) -> None: + """Handle A2A client errors with appropriate messages.""" + from a2a_handler.common.output import OutputContext + out = ctx if isinstance(ctx, OutputContext) else None -def _handle_client_error(e: Exception, agent_url: str) -> None: - """Handle A2A client errors with appropriate messages.""" + msg = "" if isinstance(e, A2AClientTimeoutError): log.error("Request to %s timed out", agent_url) - print_error("Request timed out") + msg = "Request timed out" elif isinstance(e, A2AClientHTTPError): log.error("A2A client error: %s", e) - if "connection" in str(e).lower(): - print_error(f"Connection failed: Is the server running at {agent_url}?") - else: - print_error(str(e)) + msg = ( + f"Connection failed: Is the server running at {agent_url}?" + if "connection" in str(e).lower() + else str(e) + ) elif isinstance(e, A2AClientError): log.error("A2A client error: %s", e) - print_error(str(e)) + msg = str(e) elif isinstance(e, httpx.ConnectError): log.error("Connection refused to %s", agent_url) - print_error(f"Connection refused: Is the server running at {agent_url}?") + msg = f"Connection refused: Is the server running at {agent_url}?" elif isinstance(e, httpx.TimeoutException): log.error("Request to %s timed out", agent_url) - print_error("Request timed out") + msg = "Request timed out" elif isinstance(e, httpx.HTTPStatusError): - log.error( - "HTTP error %d from %s: %s", - e.response.status_code, - agent_url, - e.response.text, - ) - print_error(f"HTTP {e.response.status_code} - {e.response.text}") + log.error("HTTP error %d from %s", e.response.status_code, agent_url) + msg = f"HTTP {e.response.status_code} - {e.response.text}" else: log.exception("Failed request to %s", agent_url) - print_error(str(e)) + msg = str(e) + + if out: + out.out_error(msg) + else: + click.echo(f"Error: {msg}", err=True) + + +# ============================================================================ +# Main CLI Group +# ============================================================================ @click.group() -@click.option("--verbose", "-v", is_flag=True, help="Enable verbose logging output") -@click.option("--debug", "-d", is_flag=True, help="Enable debug logging output") +@click.option("--verbose", "-v", is_flag=True, help="Enable verbose logging") +@click.option("--debug", "-d", is_flag=True, help="Enable debug logging") +@click.option("--raw", "-r", is_flag=True, help="Output raw text without formatting") @click.pass_context -def cli(ctx, verbose: bool, debug: bool) -> None: - """Handler A2A protocol client CLI.""" +def cli(ctx: click.Context, verbose: bool, debug: bool, raw: bool) -> None: + """Handler - A2A protocol client CLI.""" ctx.ensure_object(dict) + ctx.obj["raw"] = raw + if debug: - log.debug("Debug logging enabled") setup_logging(level="DEBUG") elif verbose: - log.debug("Verbose logging enabled") setup_logging(level="INFO") else: setup_logging(level="ERROR") -def _format_send_result(result: SendResult, output: str) -> None: - """Format and display a send result.""" +def get_mode(ctx: click.Context, output: str) -> str: + """Get output mode from context and output option.""" if output == "json": - print_json(json.dumps(result.raw, indent=2)) - return + return "json" + if ctx.obj.get("raw"): + return "raw" + return "text" - content_parts = [] - if result.context_id: - content_parts.append(f"[bold]Context ID:[/bold] [dim]{result.context_id}[/dim]") - if result.task_id: - content_parts.append(f"[bold]Task ID:[/bold] [dim]{result.task_id}[/dim]") - if result.state: - state_color = "green" if result.is_complete else "yellow" - content_parts.append( - f"[bold]State:[/bold] [{state_color}]{result.state.value}[/{state_color}]" - ) +# ============================================================================ +# Message Commands +# ============================================================================ - if content_parts: - console.print("\n".join(content_parts)) - console.print() - if result.text: - print_markdown(result.text, title="Response") - else: - console.print("[dim]No text content in response[/dim]") +@cli.group() +def message() -> None: + """Send messages to A2A agents.""" + pass -def _format_task_result(result: TaskResult, output: str) -> None: - """Format and display a task result.""" - if output == "json": - print_json(json.dumps(result.raw, indent=2)) - return +@message.command("send") +@click.argument("agent_url") +@click.argument("text") +@click.option("--stream", "-s", is_flag=True, help="Stream responses in real-time") +@click.option("--context-id", help="Context ID for conversation continuity") +@click.option("--task-id", help="Task ID to continue") +@click.option( + "--continue", "-C", "use_session", is_flag=True, help="Continue from saved session" +) +@click.option("--push-url", help="Webhook URL for push notifications") +@click.option("--push-token", help="Authentication token for push notifications") +@click.option( + "--output", + "-o", + type=click.Choice(["json", "text"]), + default="text", + help="Output format", +) +@click.pass_context +def message_send( + ctx: click.Context, + agent_url: str, + text: str, + stream: bool, + context_id: Optional[str], + task_id: Optional[str], + use_session: bool, + push_url: Optional[str], + push_token: Optional[str], + output: str, +) -> None: + """Send a message to an A2A agent. - state_color = "green" if result.state.value in ("completed",) else "yellow" - if result.state.value in ("failed", "rejected", "canceled"): - state_color = "red" + Send `TEXT` to the agent at `AGENT_URL` and display the response. - content_parts = [ - f"[bold]Task ID:[/bold] [dim]{result.task_id}[/dim]", - f"[bold]State:[/bold] [{state_color}]{result.state.value}[/{state_color}]", - ] + Examples: - if result.context_id: - content_parts.append(f"[bold]Context ID:[/bold] [dim]{result.context_id}[/dim]") + handler message send http://localhost:8000 "Hello, agent!" - title = f"[bold]Task {result.task_id[:8]}...[/bold]" - print_panel("\n".join(content_parts), title=title) + handler message send http://localhost:8000 "Continue our chat" --continue - if result.text: - console.print() - print_markdown(result.text, title="Content") + handler message send http://localhost:8000 "Stream this" --stream + """ + log.info("Sending message to %s", agent_url) + + if use_session and not context_id: + session = get_session(agent_url) + if session.context_id: + context_id = session.context_id + log.info("Using saved context: %s", context_id) + + mode = get_mode(ctx, output) + + async def do_send() -> None: + with get_output_context(mode) as out: + try: + async with build_http_client() as http_client: + service = A2AService( + http_client, + agent_url, + enable_streaming=stream, + push_notification_url=push_url, + push_notification_token=push_token, + ) + if mode != "json": + out.out_dim(f"Sending to {agent_url}...") -@cli.command() + if stream: + await _stream_message( + service, text, context_id, task_id, agent_url, out, output + ) + else: + result = await service.send(text, context_id, task_id) + update_session(agent_url, result.context_id, result.task_id) + _format_send_result(result, out, output) + + except Exception as e: + _handle_client_error(e, agent_url, out) + raise click.Abort() + + asyncio.run(do_send()) + + +@message.command("stream") @click.argument("agent_url") +@click.argument("text") +@click.option("--context-id", help="Context ID for conversation continuity") +@click.option("--task-id", help="Task ID to continue") +@click.option( + "--continue", "-C", "use_session", is_flag=True, help="Continue from saved session" +) +@click.option("--push-url", help="Webhook URL for push notifications") +@click.option("--push-token", help="Authentication token for push notifications") @click.option( "--output", "-o", @@ -243,104 +329,135 @@ def _format_task_result(result: TaskResult, output: str) -> None: default="text", help="Output format", ) -def card(agent_url: str, output: str) -> None: - """Fetch and display an agent card from AGENT_URL.""" - log.info("Fetching agent card from %s", agent_url) +@click.pass_context +def message_stream( + ctx: click.Context, + agent_url: str, + text: str, + context_id: Optional[str], + task_id: Optional[str], + use_session: bool, + push_url: Optional[str], + push_token: Optional[str], + output: str, +) -> None: + """Stream a message response from an A2A agent. - async def fetch() -> None: - try: - log.debug("Building HTTP client") - async with build_http_client() as client: - log.debug("Requesting agent card") - service = A2AService(client, agent_url) - card_data = await service.get_card() - log.info("Retrieved card for agent: %s", card_data.name) - - if output == "json": - log.debug("Outputting card as JSON") - print_json(card_data.model_dump_json(indent=2)) - else: - log.debug("Outputting card as formatted text") - card_dict = card_data.model_dump() - - name = card_dict.pop("name", "Unknown Agent") - description = card_dict.pop("description", "") - - title = f"[bold green]{name}[/bold green] [dim]v{__version__}[/dim]" - content_parts = [] - - if description: - content_parts.append(f"[italic]{description}[/italic]") - - for key, value in card_dict.items(): - if key.startswith("_"): - continue - formatted = format_value(value) - if formatted: - field_name = format_field_name(key) - if "\n" in formatted: - content_parts.append( - f"[bold]{field_name}:[/bold]\n{formatted}" - ) - else: - content_parts.append( - f"[bold]{field_name}:[/bold] {formatted}" - ) + Send `TEXT` to `AGENT_URL` and stream the response in real-time. + + Examples: + + handler message stream http://localhost:8000 "Tell me a story" + """ + ctx.invoke( + message_send, + agent_url=agent_url, + text=text, + stream=True, + context_id=context_id, + task_id=task_id, + use_session=use_session, + push_url=push_url, + push_token=push_token, + output=output, + ) + + +async def _stream_message( + service: A2AService, + text: str, + context_id: Optional[str], + task_id: Optional[str], + agent_url: str, + out: object, + output: str, +) -> None: + """Stream a message and handle events.""" + from a2a_handler.common.output import OutputContext + + out_ctx = out if isinstance(out, OutputContext) else None + collected_text: list[str] = [] + last_context_id: str | None = None + last_task_id: str | None = None + last_state = None + + async for event in service.stream(text, context_id, task_id): + last_context_id = event.context_id or last_context_id + last_task_id = event.task_id or last_task_id + last_state = event.state or last_state + + if output == "json": + event_data = { + "type": event.event_type, + "context_id": event.context_id, + "task_id": event.task_id, + "state": event.state.value if event.state else None, + "text": event.text, + } + if out_ctx: + out_ctx.out_json(event_data) + else: + if event.text and event.text not in collected_text: + if out_ctx: + out_ctx.out_line(event.text) + collected_text.append(event.text) - print_panel("\n\n".join(content_parts), title=title) + update_session(agent_url, last_context_id, last_task_id) - except Exception as e: - _handle_client_error(e, agent_url) - raise click.Abort() + if output != "json" and out_ctx: + out_ctx.out_blank() + if last_context_id: + out_ctx.out_field("Context ID", last_context_id, dim_value=True) + if last_task_id: + out_ctx.out_field("Task ID", last_task_id, dim_value=True) + if last_state: + out_ctx.out_state("State", last_state.value) - asyncio.run(fetch()) +def _format_send_result(result: SendResult, out: object, output: str) -> None: + """Format and display a send result.""" + from a2a_handler.common.output import OutputContext + + out_ctx = out if isinstance(out, OutputContext) else None + if not out_ctx: + return -def _format_validation_result(result: ValidationResult, output: str) -> None: - """Format and print validation result.""" if output == "json": - output_data = { - "valid": result.valid, - "source": result.source, - "sourceType": result.source_type.value, - "agentName": result.agent_name, - "protocolVersion": result.protocol_version, - "issues": [ - { - "field": issue.field_name, - "message": issue.message, - "type": issue.issue_type, - } - for issue in result.issues - ], - } - print_json(json.dumps(output_data, indent=2)) + out_ctx.out_json(result.raw) return - if result.valid: - title = "[bold green]✓ Valid Agent Card[/bold green]" - content_parts = [ - f"[bold]Agent:[/bold] {result.agent_name}", - f"[bold]Protocol Version:[/bold] {result.protocol_version}", - f"[bold]Source:[/bold] {result.source}", - ] - print_panel("\n".join(content_parts), title=title) + out_ctx.out_blank() + if result.context_id: + out_ctx.out_field("Context ID", result.context_id, dim_value=True) + if result.task_id: + out_ctx.out_field("Task ID", result.task_id, dim_value=True) + if result.state: + out_ctx.out_state("State", result.state.value) + + out_ctx.out_blank() + if result.text: + out_ctx.out_markdown(result.text) else: - title = "[bold red]✗ Invalid Agent Card[/bold red]" - content_parts = [ - f"[bold]Source:[/bold] {result.source}", - "", - f"[bold red]Errors ({len(result.issues)}):[/bold red]", - ] + out_ctx.out_dim("No text content in response") - for issue in result.issues: - content_parts.append(f" [red]✗[/red] {issue.field_name}: {issue.message}") - print_panel("\n".join(content_parts), title=title) +# ============================================================================ +# Task Commands +# ============================================================================ -@cli.command() -@click.argument("source") +@cli.group() +def task() -> None: + """Manage A2A tasks.""" + pass + + +@task.command("get") +@click.argument("agent_url") +@click.argument("task_id") +@click.option( + "--history-length", "-n", type=int, help="Number of history messages to include" +) @click.option( "--output", "-o", @@ -348,58 +465,87 @@ def _format_validation_result(result: ValidationResult, output: str) -> None: default="text", help="Output format", ) -def validate(source: str, output: str) -> None: - """Validate an agent card from a URL or file path. +@click.pass_context +def task_get( + ctx: click.Context, + agent_url: str, + task_id: str, + history_length: Optional[int], + output: str, +) -> None: + """Get the current state of a task. + + Retrieve `TASK_ID` from the agent at `AGENT_URL`. - SOURCE can be either: - - A URL (e.g., http://localhost:8000) - - A file path (e.g., ./agent-card.json) + Examples: - The command will automatically detect whether the source is a URL or file. + handler task get http://localhost:8000 abc123 """ - log.info("Validating agent card from %s", source) + log.info("Getting task %s from %s", task_id, agent_url) + mode = get_mode(ctx, output) - is_url = source.startswith(("http://", "https://")) + async def do_get() -> None: + with get_output_context(mode) as out: + try: + async with build_http_client() as http_client: + service = A2AService(http_client, agent_url) + result = await service.get_task(task_id, history_length) + _format_task_result(result, out, output) + except Exception as e: + _handle_client_error(e, agent_url, out) + raise click.Abort() - async def do_validate() -> None: - if is_url: - log.debug("Detected URL source") - async with build_http_client() as client: - result = await validate_agent_card_from_url(source, client) - else: - log.debug("Detected file source") - result = validate_agent_card_from_file(source) + asyncio.run(do_get()) - _format_validation_result(result, output) - if not result.valid: - raise SystemExit(1) +@task.command("cancel") +@click.argument("agent_url") +@click.argument("task_id") +@click.option( + "--output", + "-o", + type=click.Choice(["json", "text"]), + default="text", + help="Output format", +) +@click.pass_context +def task_cancel(ctx: click.Context, agent_url: str, task_id: str, output: str) -> None: + """Cancel a running task. - asyncio.run(do_validate()) + Cancel `TASK_ID` on the agent at `AGENT_URL`. + Examples: -@cli.command() + handler task cancel http://localhost:8000 abc123 + """ + log.info("Canceling task %s at %s", task_id, agent_url) + mode = get_mode(ctx, output) + + async def do_cancel() -> None: + with get_output_context(mode) as out: + try: + async with build_http_client() as http_client: + service = A2AService(http_client, agent_url) + + if mode != "json": + out.out_dim(f"Canceling task {task_id}...") + + result = await service.cancel_task(task_id) + _format_task_result(result, out, output) + + if mode != "json": + out.out_success("Task canceled") + + except Exception as e: + _handle_client_error(e, agent_url, out) + raise click.Abort() + + asyncio.run(do_cancel()) + + +@task.command("resubscribe") @click.argument("agent_url") -@click.argument("message") -@click.option("--stream", "-s", is_flag=True, help="Stream responses in real-time") -@click.option("--context-id", help="Context ID for conversation continuity") -@click.option("--task-id", help="Reference an existing task ID") -@click.option( - "--continue", - "-c", - "use_session", - is_flag=True, - help="Continue last conversation (use saved context_id)", -) -@click.option( - "--push-url", - "-p", - help="Webhook URL to receive push notifications for this task", -) -@click.option( - "--push-token", - help="Optional authentication token for push notifications", -) +@click.argument("task_id") @click.option( "--output", "-o", @@ -407,127 +553,96 @@ async def do_validate() -> None: default="text", help="Output format", ) -def send( - agent_url: str, - message: str, - stream: bool, - context_id: Optional[str], - task_id: Optional[str], - use_session: bool, - push_url: Optional[str], - push_token: Optional[str], - output: str, +@click.pass_context +def task_resubscribe( + ctx: click.Context, agent_url: str, task_id: str, output: str ) -> None: - """Send MESSAGE to an agent at AGENT_URL. + """Resubscribe to a task's event stream. - Use --stream to receive responses in real-time via Server-Sent Events. - Use --continue to automatically use the last context_id from previous conversation. - Use --push-url to configure push notifications for task updates. + Resume streaming `TASK_ID` from `AGENT_URL` after disconnecting. + + Examples: + + handler task resubscribe http://localhost:8000 abc123 """ - log.info("Sending message to %s", agent_url) - log.debug("Message: %s", message[:100] if len(message) > 100 else message) + log.info("Resubscribing to task %s at %s", task_id, agent_url) + mode = get_mode(ctx, output) - if use_session and not context_id: - session = get_session(agent_url) - if session.context_id: - context_id = session.context_id - log.info("Using saved context ID: %s", context_id) - - if context_id: - log.debug("Using context ID: %s", context_id) - if task_id: - log.debug("Using task ID: %s", task_id) - - async def send_msg() -> None: - try: - async with build_http_client() as http_client: - service = A2AService( - http_client, - agent_url, - enable_streaming=stream, - push_notification_url=push_url, - push_notification_token=push_token, - ) - - if output == "text": - console.print(f"[dim]Sending message to {agent_url}...[/dim]") - if push_url: - console.print(f"[dim]Push notifications: {push_url}[/dim]") - - if stream: - log.debug("Using streaming mode") - collected_text: list[str] = [] - last_context_id: str | None = None - last_task_id: str | None = None - last_state = None - - async for event in service.stream(message, context_id, task_id): - last_context_id = event.context_id or last_context_id - last_task_id = event.task_id or last_task_id - last_state = event.state or last_state + async def do_resubscribe() -> None: + with get_output_context(mode) as out: + try: + async with build_http_client() as http_client: + service = A2AService(http_client, agent_url) + + if mode != "json": + out.out_dim(f"Resubscribing to task {task_id}...") + async for event in service.resubscribe(task_id): if output == "json": - event_data = { - "type": event.event_type, - "context_id": event.context_id, - "task_id": event.task_id, - "state": event.state.value if event.state else None, - "text": event.text, - } - print_json(json.dumps(event_data)) - else: - if event.text and event.text not in collected_text: - console.print(event.text, end="", markup=False) - collected_text.append(event.text) - - update_session(agent_url, last_context_id, last_task_id) - - if output == "text": - console.print() - console.print() - info_parts = [] - if last_context_id: - info_parts.append( - f"[bold]Context ID:[/bold] [dim]{last_context_id}[/dim]" - ) - if last_task_id: - info_parts.append( - f"[bold]Task ID:[/bold] [dim]{last_task_id}[/dim]" + out.out_json( + { + "type": event.event_type, + "context_id": event.context_id, + "task_id": event.task_id, + "state": event.state.value if event.state else None, + "text": event.text, + } ) - if last_state: - info_parts.append(f"[bold]State:[/bold] {last_state.value}") - if info_parts: - console.print("\n".join(info_parts)) + else: + if event.event_type == "status": + out.out_state( + "Status", + event.state.value if event.state else "unknown", + ) + elif event.text: + out.out_line(event.text) - else: - log.debug("Using non-streaming mode") - result = await service.send(message, context_id, task_id) - update_session(agent_url, result.context_id, result.task_id) - _format_send_result(result, output) + except Exception as e: + _handle_client_error(e, agent_url, out) + raise click.Abort() - except Exception as e: - _handle_client_error(e, agent_url) - raise click.Abort() + asyncio.run(do_resubscribe()) - asyncio.run(send_msg()) +def _format_task_result(result: TaskResult, out: object, output: str) -> None: + """Format and display a task result.""" + from a2a_handler.common.output import OutputContext -@cli.group() -def tasks() -> None: - """Manage A2A tasks.""" + out_ctx = out if isinstance(out, OutputContext) else None + if not out_ctx: + return + + if output == "json": + out_ctx.out_json(result.raw) + return + + out_ctx.out_blank() + out_ctx.out_field("Task ID", result.task_id, dim_value=True) + out_ctx.out_state("State", result.state.value) + if result.context_id: + out_ctx.out_field("Context ID", result.context_id, dim_value=True) + + if result.text: + out_ctx.out_blank() + out_ctx.out_markdown(result.text) + + +# ============================================================================ +# Task Notification Commands +# ============================================================================ + + +@task.group("notification") +def task_notification() -> None: + """Manage push notification configurations for tasks.""" pass -@tasks.command("get") +@task_notification.command("set") @click.argument("agent_url") @click.argument("task_id") -@click.option( - "--history-length", - "-n", - type=int, - default=None, - help="Number of history messages to include", -) +@click.option("--url", "-u", required=True, help="Webhook URL to receive notifications") +@click.option("--token", "-t", help="Authentication token for the webhook") @click.option( "--output", "-o", @@ -535,32 +650,63 @@ def tasks() -> None: default="text", help="Output format", ) -def tasks_get( +@click.pass_context +def notification_set( + ctx: click.Context, agent_url: str, task_id: str, - history_length: Optional[int], + url: str, + token: Optional[str], output: str, ) -> None: - """Get the status of a task by TASK_ID.""" - log.info("Getting task %s from %s", task_id, agent_url) + """Set a push notification webhook for a task. + + Configure `TASK_ID` on `AGENT_URL` to send status updates to a webhook. - async def get_task() -> None: - try: - async with build_http_client() as http_client: - service = A2AService(http_client, agent_url) - result = await service.get_task(task_id, history_length) - _format_task_result(result, output) + Examples: - except Exception as e: - _handle_client_error(e, agent_url) - raise click.Abort() + handler task notification set http://localhost:8000 abc123 --url http://localhost:9000/webhook + """ + log.info("Setting push config for task %s at %s", task_id, agent_url) + mode = get_mode(ctx, output) - asyncio.run(get_task()) + async def do_set() -> None: + with get_output_context(mode) as out: + try: + async with build_http_client() as http_client: + service = A2AService(http_client, agent_url) + if mode != "json": + out.out_dim( + f"Setting notification config for task {task_id}..." + ) -@tasks.command("cancel") + config = await service.set_push_config(task_id, url, token) + + if output == "json": + out.out_json(config.model_dump()) + else: + out.out_success("Push notification config set") + out.out_field("Task ID", config.task_id) + if config.push_notification_config: + pnc = config.push_notification_config + out.out_field("URL", pnc.url) + if pnc.token: + out.out_field("Token", f"{pnc.token[:20]}...") + if pnc.id: + out.out_field("Config ID", pnc.id) + + except Exception as e: + _handle_client_error(e, agent_url, out) + raise click.Abort() + + asyncio.run(do_set()) + + +@task_notification.command("get") @click.argument("agent_url") @click.argument("task_id") +@click.argument("config_id") @click.option( "--output", "-o", @@ -568,36 +714,53 @@ async def get_task() -> None: default="text", help="Output format", ) -def tasks_cancel( +@click.pass_context +def notification_get( + ctx: click.Context, agent_url: str, task_id: str, + config_id: str, output: str, ) -> None: - """Cancel a running task by TASK_ID.""" - log.info("Canceling task %s at %s", task_id, agent_url) + """Get a push notification config by ID. - async def cancel_task() -> None: - try: - async with build_http_client() as http_client: - service = A2AService(http_client, agent_url) + Retrieve `CONFIG_ID` for `TASK_ID` from `AGENT_URL`. - if output == "text": - console.print(f"[dim]Canceling task {task_id}...[/dim]") + Examples: - result = await service.cancel_task(task_id) - _format_task_result(result, output) + handler task notification get http://localhost:8000 abc123 config456 + """ + log.info("Getting push config %s for task %s at %s", config_id, task_id, agent_url) + mode = get_mode(ctx, output) - if output == "text": - console.print("[green]Task canceled successfully[/green]") + async def do_get() -> None: + with get_output_context(mode) as out: + try: + async with build_http_client() as http_client: + service = A2AService(http_client, agent_url) + config = await service.get_push_config(task_id, config_id) + + if output == "json": + out.out_json(config.model_dump()) + else: + out.out_header("Push Notification Config") + out.out_field("Task ID", config.task_id) + if config.push_notification_config: + pnc = config.push_notification_config + out.out_field("URL", pnc.url) + if pnc.token: + out.out_field("Token", f"{pnc.token[:20]}...") + if pnc.id: + out.out_field("Config ID", pnc.id) - except Exception as e: - _handle_client_error(e, agent_url) - raise click.Abort() + except Exception as e: + _handle_client_error(e, agent_url, out) + raise click.Abort() - asyncio.run(cancel_task()) + asyncio.run(do_get()) -@tasks.command("resubscribe") +@task_notification.command("list") @click.argument("agent_url") @click.argument("task_id") @click.option( @@ -607,55 +770,59 @@ async def cancel_task() -> None: default="text", help="Output format", ) -def tasks_resubscribe( +@click.pass_context +def notification_list( + ctx: click.Context, agent_url: str, task_id: str, output: str, ) -> None: - """Resubscribe to a task's event stream by TASK_ID. + """List all push notification configs for a task. - This resumes streaming for a task that you previously disconnected from. - """ - log.info("Resubscribing to task %s at %s", task_id, agent_url) + List all configs for `TASK_ID` on `AGENT_URL`. - async def resubscribe() -> None: - try: - async with build_http_client() as http_client: - service = A2AService(http_client, agent_url) + Examples: - if output == "text": - console.print(f"[dim]Resubscribing to task {task_id}...[/dim]") + handler task notification list http://localhost:8000 abc123 + """ + log.info("Listing push configs for task %s at %s", task_id, agent_url) + mode = get_mode(ctx, output) + + async def do_list() -> None: + with get_output_context(mode) as out: + try: + async with build_http_client() as http_client: + service = A2AService(http_client, agent_url) + configs = await service.list_push_configs(task_id) - async for event in service.resubscribe(task_id): if output == "json": - event_data = { - "type": event.event_type, - "context_id": event.context_id, - "task_id": event.task_id, - "state": event.state.value if event.state else None, - "text": event.text, - } - print_json(json.dumps(event_data)) + out.out_json([c.model_dump() for c in configs]) else: - if event.event_type == "status": - console.print( - f"[dim]Status:[/dim] {event.state.value if event.state else 'unknown'}" - ) - elif event.text: - console.print(event.text, markup=False) + if not configs: + out.out_dim("No push notification configs") + return - except Exception as e: - _handle_client_error(e, agent_url) - raise click.Abort() + out.out_header(f"Push Notification Configs ({len(configs)})") + for config in configs: + out.out_blank() + out.out_field("Task ID", config.task_id) + if config.push_notification_config: + pnc = config.push_notification_config + out.out_field("URL", pnc.url) + if pnc.id: + out.out_field("Config ID", pnc.id) - asyncio.run(resubscribe()) + except Exception as e: + _handle_client_error(e, agent_url, out) + raise click.Abort() + asyncio.run(do_list()) -@tasks.command("push-set") + +@task_notification.command("delete") @click.argument("agent_url") @click.argument("task_id") -@click.option("--url", "-u", required=True, help="Webhook URL to receive notifications") -@click.option("--token", "-t", help="Optional authentication token") +@click.argument("config_id") @click.option( "--output", "-o", @@ -663,61 +830,64 @@ async def resubscribe() -> None: default="text", help="Output format", ) -def tasks_push_set( +@click.pass_context +def notification_delete( + ctx: click.Context, agent_url: str, task_id: str, - url: str, - token: Optional[str], + config_id: str, output: str, ) -> None: - """Set push notification config for a task. + """Delete a push notification config. - Configure the agent to send push notifications to a webhook URL - when task status changes. + Delete `CONFIG_ID` for `TASK_ID` on `AGENT_URL`. - Example: - handler tasks push-set http://localhost:8000 TASK_ID --url http://localhost:9000/webhook + Examples: + + handler task notification delete http://localhost:8000 abc123 config456 """ - log.info("Setting push config for task %s at %s", task_id, agent_url) + log.info("Deleting push config %s for task %s at %s", config_id, task_id, agent_url) + mode = get_mode(ctx, output) - async def set_push() -> None: - try: - async with build_http_client() as http_client: - service = A2AService(http_client, agent_url) + async def do_delete() -> None: + with get_output_context(mode) as out: + try: + async with build_http_client() as http_client: + service = A2AService(http_client, agent_url) - if output == "text": - console.print( - f"[dim]Setting push notification config for task {task_id}...[/dim]" - ) + if mode != "json": + out.out_dim(f"Deleting config {config_id}...") - config = await service.set_push_config(task_id, url, token) + await service.delete_push_config(task_id, config_id) - if output == "json": - print_json(config.model_dump_json(indent=2)) - else: - console.print( - "[green]Push notification config set successfully[/green]" - ) - console.print(f"[bold]Task ID:[/bold] {config.task_id}") - if config.push_notification_config: - pnc = config.push_notification_config - console.print(f"[bold]URL:[/bold] {pnc.url}") - if pnc.token: - console.print(f"[bold]Token:[/bold] {pnc.token[:20]}...") - if pnc.id: - console.print(f"[bold]Config ID:[/bold] {pnc.id}") + if output == "json": + out.out_json({"deleted": True, "config_id": config_id}) + else: + out.out_success(f"Deleted config {config_id}") - except Exception as e: - _handle_client_error(e, agent_url) - raise click.Abort() + except Exception as e: + _handle_client_error(e, agent_url, out) + raise click.Abort() - asyncio.run(set_push()) + asyncio.run(do_delete()) -@tasks.command("push-get") +# ============================================================================ +# Card Commands +# ============================================================================ + + +@cli.group() +def card() -> None: + """Agent card operations.""" + pass + + +@card.command("get") @click.argument("agent_url") -@click.argument("task_id") -@click.argument("config_id") +@click.option( + "--authenticated", "-a", is_flag=True, help="Request authenticated extended card" +) @click.option( "--output", "-o", @@ -725,56 +895,212 @@ async def set_push() -> None: default="text", help="Output format", ) -def tasks_push_get( - agent_url: str, - task_id: str, - config_id: str, - output: str, +@click.pass_context +def card_get( + ctx: click.Context, agent_url: str, authenticated: bool, output: str ) -> None: - """Get push notification config for a task.""" - log.info("Getting push config %s for task %s at %s", config_id, task_id, agent_url) + """Fetch and display an agent card. - async def get_push() -> None: - try: - async with build_http_client() as http_client: - service = A2AService(http_client, agent_url) - config = await service.get_push_config(task_id, config_id) + Retrieve the agent card from `AGENT_URL`. - if output == "json": - print_json(config.model_dump_json(indent=2)) - else: - console.print("[bold]Push Notification Config[/bold]") - console.print(f"[bold]Task ID:[/bold] {config.task_id}") - if config.push_notification_config: - pnc = config.push_notification_config - console.print(f"[bold]URL:[/bold] {pnc.url}") - if pnc.token: - console.print(f"[bold]Token:[/bold] {pnc.token[:20]}...") + Examples: - except Exception as e: - _handle_client_error(e, agent_url) - raise click.Abort() + handler card get http://localhost:8000 - asyncio.run(get_push()) + handler card get http://localhost:8000 --authenticated + """ + log.info("Fetching agent card from %s", agent_url) + mode = get_mode(ctx, output) + async def do_get() -> None: + with get_output_context(mode) as out: + try: + async with build_http_client() as http_client: + service = A2AService(http_client, agent_url) + card_data = await service.get_card() + log.info("Retrieved card for agent: %s", card_data.name) -@cli.command() + if output == "json": + out.out_json(card_data.model_dump()) + else: + _format_agent_card(card_data, out) + + except Exception as e: + _handle_client_error(e, agent_url, out) + raise click.Abort() + + asyncio.run(do_get()) + + +def _format_agent_card(card_data: object, out: object) -> None: + """Format and display an agent card.""" + from typing import Any + + from a2a.types import AgentCard + + from a2a_handler.common.output import OutputContext + + out_ctx = out if isinstance(out, OutputContext) else None + if not out_ctx: + return + + card_dict: dict[str, Any] + if isinstance(card_data, AgentCard): + card_dict = card_data.model_dump() + else: + card_dict = {} + name = card_dict.pop("name", "Unknown Agent") + description = card_dict.pop("description", "") + + out_ctx.out_header(name) + if description: + out_ctx.out_line(description) + + out_ctx.out_blank() + for key, value in card_dict.items(): + if key.startswith("_"): + continue + formatted = format_value(value) + if formatted: + field_name = format_field_name(key) + if "\n" in formatted: + out_ctx.out_line(f"{field_name}:") + out_ctx.out_line(formatted) + else: + out_ctx.out_field(field_name, formatted) + + +@card.command("validate") +@click.argument("source") +@click.option( + "--output", + "-o", + type=click.Choice(["json", "text"]), + default="text", + help="Output format", +) +@click.pass_context +def card_validate(ctx: click.Context, source: str, output: str) -> None: + """Validate an agent card from a URL or file path. + + Validate the agent card at `SOURCE` (URL or file path). + + Examples: + + handler card validate http://localhost:8000 + + handler card validate ./my-agent-card.json + """ + log.info("Validating agent card from %s", source) + is_url = source.startswith(("http://", "https://")) + mode = get_mode(ctx, output) + + async def do_validate() -> None: + with get_output_context(mode) as out: + if is_url: + async with build_http_client() as http_client: + result = await validate_agent_card_from_url(source, http_client) + else: + result = validate_agent_card_from_file(source) + + _format_validation_result(result, out, output) + + if not result.valid: + raise SystemExit(1) + + asyncio.run(do_validate()) + + +def _format_validation_result( + result: ValidationResult, out: object, output: str +) -> None: + """Format and display validation result.""" + from a2a_handler.common.output import OutputContext + + out_ctx = out if isinstance(out, OutputContext) else None + if not out_ctx: + return + + if output == "json": + out_ctx.out_json( + { + "valid": result.valid, + "source": result.source, + "sourceType": result.source_type.value, + "agentName": result.agent_name, + "protocolVersion": result.protocol_version, + "issues": [ + {"field": i.field_name, "message": i.message, "type": i.issue_type} + for i in result.issues + ], + } + ) + return + + if result.valid: + out_ctx.out_success("Valid Agent Card") + out_ctx.out_field("Agent", result.agent_name) + out_ctx.out_field("Protocol Version", result.protocol_version) + out_ctx.out_field("Source", result.source) + else: + out_ctx.out_error("Invalid Agent Card") + out_ctx.out_field("Source", result.source) + out_ctx.out_blank() + out_ctx.out_line(f"Errors ({len(result.issues)}):") + for issue in result.issues: + out_ctx.out_list_item(f"{issue.field_name}: {issue.message}", bullet="✗") + + +# ============================================================================ +# Server Commands +# ============================================================================ + + +@cli.group() +def server() -> None: + """Run local servers.""" + pass + + +@server.command("agent") +@click.option("--host", default="0.0.0.0", help="Host to bind to", show_default=True) +@click.option("--port", default=8000, help="Port to bind to", show_default=True) +def server_agent(host: str, port: int) -> None: + """Start the A2A agent server backed by Ollama. + + Examples: + + handler server agent + + handler server agent --port 9000 + """ + log.info("Starting A2A server on %s:%d", host, port) + run_server(host, port) + + +@server.command("push") @click.option("--host", default="127.0.0.1", help="Host to bind to", show_default=True) @click.option("--port", default=9000, help="Port to bind to", show_default=True) -def webhook(host: str, port: int) -> None: - """Start a local webhook server to receive push notifications. +def server_push(host: str, port: int) -> None: + """Start a local webhook server for push notifications. + + Receives and displays push notifications from A2A agents. Useful for testing. + + Examples: - This starts a simple HTTP server that receives and displays - push notifications from A2A agents. Useful for testing. + handler server push - Example: - handler webhook --port 9000 - # Then use http://localhost:9000/webhook as your push notification URL + handler server push --port 9001 """ log.info("Starting webhook server on %s:%d", host, port) run_webhook_server(host, port) +# ============================================================================ +# Session Commands +# ============================================================================ + + @cli.group() def session() -> None: """Manage saved session state.""" @@ -782,78 +1108,135 @@ def session() -> None: @session.command("list") -def session_list() -> None: - """List all saved sessions.""" - store = get_session_store() - sessions = store.list_all() +@click.pass_context +def session_list(ctx: click.Context) -> None: + """List all saved sessions. - if not sessions: - console.print("[dim]No saved sessions[/dim]") - return + Display all agents with saved session state. + + Examples: + + handler session list + """ + mode = "raw" if ctx.obj.get("raw") else "text" - console.print(f"[bold]Saved Sessions ({len(sessions)}):[/bold]") - console.print() - for s in sessions: - console.print(f"[bold cyan]{s.agent_url}[/bold cyan]") - if s.context_id: - console.print(f" [dim]Context ID:[/dim] {s.context_id}") - if s.task_id: - console.print(f" [dim]Task ID:[/dim] {s.task_id}") + with get_output_context(mode) as out: + store = get_session_store() + sessions = store.list_all() + + if not sessions: + out.out_dim("No saved sessions") + return + + out.out_header(f"Saved Sessions ({len(sessions)})") + for s in sessions: + out.out_blank() + out.out_subheader(s.agent_url) + if s.context_id: + out.out_field("Context ID", s.context_id, dim_value=True) + if s.task_id: + out.out_field("Task ID", s.task_id, dim_value=True) @session.command("show") @click.argument("agent_url") -def session_show(agent_url: str) -> None: - """Show session for a specific agent.""" - s = get_session(agent_url) - console.print(f"[bold]Session for {agent_url}[/bold]") - console.print(f"[bold]Context ID:[/bold] {s.context_id or '[dim]none[/dim]'}") - console.print(f"[bold]Task ID:[/bold] {s.task_id or '[dim]none[/dim]'}") +@click.pass_context +def session_show(ctx: click.Context, agent_url: str) -> None: + """Show the saved session for an agent. + + Display the session for `AGENT_URL`. + + Examples: + + handler session show http://localhost:8000 + """ + mode = "raw" if ctx.obj.get("raw") else "text" + + with get_output_context(mode) as out: + s = get_session(agent_url) + out.out_header(f"Session for {agent_url}") + out.out_field("Context ID", s.context_id or "none", dim_value=not s.context_id) + out.out_field("Task ID", s.task_id or "none", dim_value=not s.task_id) + + +@session.command("get") +@click.argument("agent_url") +@click.pass_context +def session_get(ctx: click.Context, agent_url: str) -> None: + """Get a specific session value. + + Retrieve the session for `AGENT_URL`. + + Examples: + + handler session get http://localhost:8000 + """ + ctx.invoke(session_show, agent_url=agent_url) @session.command("clear") @click.argument("agent_url", required=False) @click.option("--all", "-a", "clear_all", is_flag=True, help="Clear all sessions") -def session_clear(agent_url: Optional[str], clear_all: bool) -> None: - """Clear saved session(s). +@click.pass_context +def session_clear( + ctx: click.Context, agent_url: Optional[str], clear_all: bool +) -> None: + """Clear saved sessions. + + Clear the session for `AGENT_URL`, or use `--all` to clear all sessions. - Provide AGENT_URL to clear a specific session, or use --all to clear all. + Examples: + + handler session clear http://localhost:8000 + + handler session clear --all """ - if clear_all: - clear_session() - console.print("[green]Cleared all sessions[/green]") - elif agent_url: - clear_session(agent_url) - console.print(f"[green]Cleared session for {agent_url}[/green]") - else: - console.print( - "[yellow]Provide AGENT_URL or use --all to clear sessions[/yellow]" - ) + mode = "raw" if ctx.obj.get("raw") else "text" + + with get_output_context(mode) as out: + if clear_all: + clear_session() + out.out_success("Cleared all sessions") + elif agent_url: + clear_session(agent_url) + out.out_success(f"Cleared session for {agent_url}") + else: + out.out_warning("Provide AGENT_URL or use --all to clear sessions") -@cli.command() -def tui() -> None: - """Launch the TUI.""" - log.info("Launching TUI") - logging.getLogger().handlers = [] - app = HandlerTUI() - app.run() +# ============================================================================ +# Utility Commands +# ============================================================================ @cli.command() def version() -> None: - """Display the current version.""" - log.debug("Displaying version: %s", __version__) + """Display the current version. + + Examples: + + handler version + """ click.echo(__version__) @cli.command() -@click.option("--host", default="0.0.0.0", help="Host to bind to", show_default=True) -@click.option("--port", default=8000, help="Port to bind to", show_default=True) -def server(host: str, port: int) -> None: - """Start the A2A server agent backed by Ollama.""" - log.info("Starting A2A server on %s:%d", host, port) - run_server(host, port) +def tui() -> None: + """Launch the interactive TUI. + + Examples: + + handler tui + """ + log.info("Launching TUI") + logging.getLogger().handlers = [] + app = HandlerTUI() + app.run() + + +# ============================================================================ +# Entry Point +# ============================================================================ def main() -> None: diff --git a/src/a2a_handler/common/__init__.py b/src/a2a_handler/common/__init__.py index 1d587f8..3838ee6 100644 --- a/src/a2a_handler/common/__init__.py +++ b/src/a2a_handler/common/__init__.py @@ -11,31 +11,21 @@ get_logger, setup_logging, ) -from .printing import ( - BorderStyle, - print_error, - print_info, - print_json, - print_markdown, - print_panel, - print_success, - print_warning, +from .output import ( + OutputContext, + OutputMode, + get_output_context, ) __all__ = [ - "BorderStyle", "HANDLER_THEME", "LogLevel", + "OutputContext", + "OutputMode", "console", "format_field_name", "format_value", "get_logger", - "print_error", - "print_info", - "print_json", - "print_markdown", - "print_panel", - "print_success", - "print_warning", + "get_output_context", "setup_logging", ] diff --git a/src/a2a_handler/common/output.py b/src/a2a_handler/common/output.py new file mode 100644 index 0000000..6a4a613 --- /dev/null +++ b/src/a2a_handler/common/output.py @@ -0,0 +1,177 @@ +"""Output formatting system with mode-aware styling.""" + +from __future__ import annotations + +import json as json_module +import re +from contextlib import contextmanager +from enum import Enum +from typing import Any, Generator + +from rich.console import Console +from rich.markdown import Markdown + +from .logging import console + + +TERMINAL_STATES = {"completed", "failed", "canceled", "rejected"} +SUCCESS_STATES = {"completed"} +ERROR_STATES = {"failed", "rejected"} +WARNING_STATES = {"canceled"} + + +class OutputMode(Enum): + """Output mode for CLI commands.""" + + RAW = "raw" + TEXT = "text" + JSON = "json" + + +def _strip_markup(text: str) -> str: + """Strip Rich markup for raw output.""" + return re.sub(r"\[/?[^\]]+\]", "", text) + + +class OutputContext: + """Manages output mode and styling. + + Provides a unified interface for outputting text, fields, JSON, and + markdown with automatic mode-aware formatting. + """ + + def __init__(self, mode: OutputMode) -> None: + """Initialize output context. + + Args: + mode: Output mode (raw, text, or json) + """ + self.mode = mode + self._raw_console = Console(highlight=False, markup=False) + + def _print(self, text: str, style: str | None = None) -> None: + """Internal print method that respects mode.""" + if self.mode == OutputMode.RAW: + self._raw_console.print(_strip_markup(text)) + elif self.mode == OutputMode.TEXT and style: + console.print(text, style=style) + else: + console.print(text, markup=self.mode == OutputMode.TEXT) + + def out_line(self, text: str, style: str | None = None) -> None: + """Print a line of text with optional style.""" + self._print(text, style) + + def out_field( + self, + name: str, + value: Any, + dim_value: bool = False, + value_style: str | None = None, + ) -> None: + """Print a field as 'Name: value' with formatting.""" + value_str = str(value) if value is not None else "none" + + if self.mode == OutputMode.TEXT: + if value_style: + console.print(f"[bold]{name}:[/bold] [{value_style}]{value_str}[/]") + elif dim_value: + console.print(f"[bold]{name}:[/bold] [dim]{value_str}[/dim]") + else: + console.print(f"[bold]{name}:[/bold] {value_str}") + else: + self._raw_console.print(f"{name}: {_strip_markup(value_str)}") + + def out_header(self, text: str) -> None: + """Print a section header.""" + if self.mode == OutputMode.TEXT: + console.print(f"\n[bold]{text}[/bold]") + else: + self._raw_console.print(f"\n{text}") + + def out_subheader(self, text: str) -> None: + """Print a subheader (less prominent than header).""" + if self.mode == OutputMode.TEXT: + console.print(f"[bold cyan]{text}[/bold cyan]") + else: + self._raw_console.print(text) + + def out_blank(self) -> None: + """Print a blank line.""" + if self.mode == OutputMode.TEXT: + console.print() + else: + self._raw_console.print() + + def out_state(self, name: str, state: str) -> None: + """Print a state field with appropriate coloring.""" + if self.mode == OutputMode.TEXT: + lower = state.lower() + if lower in SUCCESS_STATES: + style = "green" + elif lower in ERROR_STATES: + style = "red" + elif lower in WARNING_STATES: + style = "yellow" + elif lower in TERMINAL_STATES: + style = "bold" + else: + style = "cyan" + console.print(f"[bold]{name}:[/bold] [{style}]{state}[/{style}]") + else: + self._raw_console.print(f"{name}: {state}") + + def out_success(self, text: str) -> None: + """Print a success message.""" + self._print(text, "green") + + def out_error(self, text: str) -> None: + """Print an error message.""" + self._print(text, "red bold") + + def out_warning(self, text: str) -> None: + """Print a warning message.""" + self._print(text, "yellow") + + def out_dim(self, text: str) -> None: + """Print dimmed/muted text.""" + self._print(text, "dim") + + def out_json(self, data: Any) -> None: + """Print JSON data.""" + json_str = json_module.dumps(data, indent=2, default=str) + self._raw_console.print(json_str) + + def out_markdown(self, text: str) -> None: + """Print markdown content.""" + if self.mode == OutputMode.TEXT: + console.print(Markdown(text)) + else: + self._raw_console.print(text) + + def out_list_item(self, text: str, bullet: str = "•") -> None: + """Print a list item with bullet.""" + if self.mode == OutputMode.TEXT: + console.print(f" [dim]{bullet}[/dim] {text}") + else: + self._raw_console.print(f" {bullet} {_strip_markup(text)}") + + +_current_context: OutputContext | None = None + + +@contextmanager +def get_output_context( + mode: OutputMode | str, +) -> Generator[OutputContext, None, None]: + global _current_context + + if isinstance(mode, str): + mode = OutputMode(mode) + + ctx = OutputContext(mode) + _current_context = ctx + try: + yield ctx + finally: + _current_context = None diff --git a/src/a2a_handler/common/printing.py b/src/a2a_handler/common/printing.py deleted file mode 100644 index 06a9c43..0000000 --- a/src/a2a_handler/common/printing.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Unified Rich printing configuration for Handler packages.""" - -from typing import Literal - -from rich.markdown import Markdown -from rich.panel import Panel -from rich.json import JSON - -from .logging import console - -BorderStyle = Literal["green", "blue", "yellow", "red", "cyan", "magenta", "dim"] - - -def print_panel( - content: str, - title: str | None = None, - border_style: BorderStyle = "green", - expand: bool = False, - markdown: bool = False, -) -> None: - """Print content in a Rich panel. - - Args: - content: The content to display - title: Optional panel title - border_style: Border color/style - expand: Whether to expand panel to full width - markdown: Whether to render content as markdown - """ - renderable = Markdown(content) if markdown else content - console.print( - Panel(renderable, title=title, border_style=border_style, expand=expand) - ) - - -def print_info(content: str, title: str | None = None) -> None: - """Print an info panel (cyan border).""" - print_panel(content, title=title, border_style="cyan") - - -def print_success(content: str, title: str | None = None) -> None: - """Print a success panel (green border).""" - print_panel(content, title=title, border_style="green") - - -def print_warning(content: str, title: str | None = None) -> None: - """Print a warning panel (yellow border).""" - print_panel(content, title=title, border_style="yellow") - - -def print_error(content: str, title: str | None = None) -> None: - """Print an error panel (red border).""" - print_panel(content, title=title, border_style="red") - - -def print_json(data: str, title: str | None = None) -> None: - """Print JSON in a panel with structural highlighting. - - Args: - data: JSON string to display - title: Optional panel title - """ - json_renderable = JSON(data, highlight=False) - console.print( - Panel(json_renderable, title=title, border_style="green", expand=False) - ) - - -def print_markdown( - content: str, title: str | None = None, border_style: BorderStyle = "green" -) -> None: - """Print markdown content in a panel. - - Args: - content: Markdown content to display - title: Optional panel title - border_style: Border color/style - """ - print_panel(content, title=title, border_style=border_style, markdown=True) diff --git a/src/a2a_handler/service.py b/src/a2a_handler/service.py index aa102b8..3cafb6e 100644 --- a/src/a2a_handler/service.py +++ b/src/a2a_handler/service.py @@ -103,7 +103,7 @@ def extract_text_from_message_parts(message_parts: list[Part] | None) -> str: def extract_text_from_task(task: Task) -> str: - """Extract text from task artifacts and history.""" + """Extract text from task artifacts, falling back to history if no artifacts.""" extracted_texts = [] if task.artifacts: @@ -111,7 +111,8 @@ def extract_text_from_task(task: Task) -> str: if artifact.parts: extracted_texts.append(extract_text_from_message_parts(artifact.parts)) - if task.history: + # Only check history if no artifacts found (avoids duplication) + if not extracted_texts and task.history: for message in task.history: if message.role == Role.agent and message.parts: extracted_texts.append(extract_text_from_message_parts(message.parts)) @@ -540,3 +541,77 @@ async def get_push_config( logger.info("Getting push config %s for task %s", config_id, task_id) return await client.get_task_callback(query_params) + + async def list_push_configs( + self, + task_id: str, + ) -> list[TaskPushNotificationConfig]: + """List all push notification configurations for a task. + + Args: + task_id: ID of the task + + Returns: + List of push notification configurations + + Note: + This method uses raw JSON-RPC as the SDK doesn't expose this yet. + """ + await self.get_card() + logger.info("Listing push configs for task %s", task_id) + + request_payload = { + "jsonrpc": "2.0", + "method": "tasks/pushNotificationConfig/list", + "params": {"id": task_id}, + "id": str(uuid.uuid4()), + } + + response = await self.http_client.post( + self.agent_url, + json=request_payload, + headers={"Content-Type": "application/json"}, + ) + response.raise_for_status() + result = response.json() + + if "error" in result: + raise RuntimeError(result["error"].get("message", "Unknown error")) + + configs_data = result.get("result", []) + return [TaskPushNotificationConfig.model_validate(c) for c in configs_data] + + async def delete_push_config( + self, + task_id: str, + config_id: str, + ) -> None: + """Delete a push notification configuration for a task. + + Args: + task_id: ID of the task + config_id: ID of the push config to delete + + Note: + This method uses raw JSON-RPC as the SDK doesn't expose this yet. + """ + await self.get_card() + logger.info("Deleting push config %s for task %s", config_id, task_id) + + request_payload = { + "jsonrpc": "2.0", + "method": "tasks/pushNotificationConfig/delete", + "params": {"id": task_id, "push_notification_config_id": config_id}, + "id": str(uuid.uuid4()), + } + + response = await self.http_client.post( + self.agent_url, + json=request_payload, + headers={"Content-Type": "application/json"}, + ) + response.raise_for_status() + result = response.json() + + if "error" in result: + raise RuntimeError(result["error"].get("message", "Unknown error")) From edd6dd44fe1e25dcf166cf39bc6be98bf5a91066 Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Tue, 9 Dec 2025 19:03:13 -0500 Subject: [PATCH 15/23] chore: add download badge back, the rate limit might've been a one-off issue --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 262faf1..e89bec9 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ [![A2A Protocol](https://img.shields.io/badge/A2A_Protocol-v0.3.0-blue)](https://a2a-protocol.org/latest/) [![PyPI version](https://img.shields.io/pypi/v/a2a-handler)](https://pypi.org/project/a2a-handler/) [![PyPI - Status](https://img.shields.io/pypi/status/a2a-handler)](https://pypi.org/project/a2a-handler/) +[![PyPI downloads](https://img.shields.io/pypi/dm/a2a-handler)](https://pypi.org/project/a2a-handler/) [![GitHub stars](https://img.shields.io/github/stars/alDuncanson/handler)](https://github.com/alDuncanson/handler/stargazers) An [A2A](https://a2a-protocol.org/latest/) Protocol client TUI and CLI. From 76c4fa8ccbf3357d3a599639652d443a01f43d5a Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Tue, 9 Dec 2025 19:03:28 -0500 Subject: [PATCH 16/23] refactor: shorten and standardize CLI command docstrings --- src/a2a_handler/cli.py | 179 +++++------------------------------------ 1 file changed, 19 insertions(+), 160 deletions(-) diff --git a/src/a2a_handler/cli.py b/src/a2a_handler/cli.py index 0bf2d76..56bf1a7 100644 --- a/src/a2a_handler/cli.py +++ b/src/a2a_handler/cli.py @@ -259,18 +259,7 @@ def message_send( push_token: Optional[str], output: str, ) -> None: - """Send a message to an A2A agent. - - Send `TEXT` to the agent at `AGENT_URL` and display the response. - - Examples: - - handler message send http://localhost:8000 "Hello, agent!" - - handler message send http://localhost:8000 "Continue our chat" --continue - - handler message send http://localhost:8000 "Stream this" --stream - """ + """Send a message to an agent and receive a response.""" log.info("Sending message to %s", agent_url) if use_session and not context_id: @@ -341,14 +330,7 @@ def message_stream( push_token: Optional[str], output: str, ) -> None: - """Stream a message response from an A2A agent. - - Send `TEXT` to `AGENT_URL` and stream the response in real-time. - - Examples: - - handler message stream http://localhost:8000 "Tell me a story" - """ + """Send a message and stream the response in real-time.""" ctx.invoke( message_send, agent_url=agent_url, @@ -473,14 +455,7 @@ def task_get( history_length: Optional[int], output: str, ) -> None: - """Get the current state of a task. - - Retrieve `TASK_ID` from the agent at `AGENT_URL`. - - Examples: - - handler task get http://localhost:8000 abc123 - """ + """Retrieve the current status of a task.""" log.info("Getting task %s from %s", task_id, agent_url) mode = get_mode(ctx, output) @@ -510,14 +485,7 @@ async def do_get() -> None: ) @click.pass_context def task_cancel(ctx: click.Context, agent_url: str, task_id: str, output: str) -> None: - """Cancel a running task. - - Cancel `TASK_ID` on the agent at `AGENT_URL`. - - Examples: - - handler task cancel http://localhost:8000 abc123 - """ + """Request cancellation of a task.""" log.info("Canceling task %s at %s", task_id, agent_url) mode = get_mode(ctx, output) @@ -557,14 +525,7 @@ async def do_cancel() -> None: def task_resubscribe( ctx: click.Context, agent_url: str, task_id: str, output: str ) -> None: - """Resubscribe to a task's event stream. - - Resume streaming `TASK_ID` from `AGENT_URL` after disconnecting. - - Examples: - - handler task resubscribe http://localhost:8000 abc123 - """ + """Resubscribe to a task's SSE stream after disconnection.""" log.info("Resubscribing to task %s at %s", task_id, agent_url) mode = get_mode(ctx, output) @@ -659,14 +620,7 @@ def notification_set( token: Optional[str], output: str, ) -> None: - """Set a push notification webhook for a task. - - Configure `TASK_ID` on `AGENT_URL` to send status updates to a webhook. - - Examples: - - handler task notification set http://localhost:8000 abc123 --url http://localhost:9000/webhook - """ + """Configure a push notification webhook for a task.""" log.info("Setting push config for task %s at %s", task_id, agent_url) mode = get_mode(ctx, output) @@ -722,14 +676,7 @@ def notification_get( config_id: str, output: str, ) -> None: - """Get a push notification config by ID. - - Retrieve `CONFIG_ID` for `TASK_ID` from `AGENT_URL`. - - Examples: - - handler task notification get http://localhost:8000 abc123 config456 - """ + """Retrieve a push notification config by ID.""" log.info("Getting push config %s for task %s at %s", config_id, task_id, agent_url) mode = get_mode(ctx, output) @@ -777,14 +724,7 @@ def notification_list( task_id: str, output: str, ) -> None: - """List all push notification configs for a task. - - List all configs for `TASK_ID` on `AGENT_URL`. - - Examples: - - handler task notification list http://localhost:8000 abc123 - """ + """List all push notification configs for a task.""" log.info("Listing push configs for task %s at %s", task_id, agent_url) mode = get_mode(ctx, output) @@ -838,14 +778,7 @@ def notification_delete( config_id: str, output: str, ) -> None: - """Delete a push notification config. - - Delete `CONFIG_ID` for `TASK_ID` on `AGENT_URL`. - - Examples: - - handler task notification delete http://localhost:8000 abc123 config456 - """ + """Delete a push notification config.""" log.info("Deleting push config %s for task %s at %s", config_id, task_id, agent_url) mode = get_mode(ctx, output) @@ -899,16 +832,7 @@ def card() -> None: def card_get( ctx: click.Context, agent_url: str, authenticated: bool, output: str ) -> None: - """Fetch and display an agent card. - - Retrieve the agent card from `AGENT_URL`. - - Examples: - - handler card get http://localhost:8000 - - handler card get http://localhost:8000 --authenticated - """ + """Retrieve an agent's card.""" log.info("Fetching agent card from %s", agent_url) mode = get_mode(ctx, output) @@ -981,16 +905,7 @@ def _format_agent_card(card_data: object, out: object) -> None: ) @click.pass_context def card_validate(ctx: click.Context, source: str, output: str) -> None: - """Validate an agent card from a URL or file path. - - Validate the agent card at `SOURCE` (URL or file path). - - Examples: - - handler card validate http://localhost:8000 - - handler card validate ./my-agent-card.json - """ + """Validate an agent card from URL or file.""" log.info("Validating agent card from %s", source) is_url = source.startswith(("http://", "https://")) mode = get_mode(ctx, output) @@ -1066,14 +981,7 @@ def server() -> None: @click.option("--host", default="0.0.0.0", help="Host to bind to", show_default=True) @click.option("--port", default=8000, help="Port to bind to", show_default=True) def server_agent(host: str, port: int) -> None: - """Start the A2A agent server backed by Ollama. - - Examples: - - handler server agent - - handler server agent --port 9000 - """ + """Start a local A2A agent server.""" log.info("Starting A2A server on %s:%d", host, port) run_server(host, port) @@ -1082,16 +990,7 @@ def server_agent(host: str, port: int) -> None: @click.option("--host", default="127.0.0.1", help="Host to bind to", show_default=True) @click.option("--port", default=9000, help="Port to bind to", show_default=True) def server_push(host: str, port: int) -> None: - """Start a local webhook server for push notifications. - - Receives and displays push notifications from A2A agents. Useful for testing. - - Examples: - - handler server push - - handler server push --port 9001 - """ + """Start a local webhook server for receiving push notifications.""" log.info("Starting webhook server on %s:%d", host, port) run_webhook_server(host, port) @@ -1110,14 +1009,7 @@ def session() -> None: @session.command("list") @click.pass_context def session_list(ctx: click.Context) -> None: - """List all saved sessions. - - Display all agents with saved session state. - - Examples: - - handler session list - """ + """List all saved sessions.""" mode = "raw" if ctx.obj.get("raw") else "text" with get_output_context(mode) as out: @@ -1142,14 +1034,7 @@ def session_list(ctx: click.Context) -> None: @click.argument("agent_url") @click.pass_context def session_show(ctx: click.Context, agent_url: str) -> None: - """Show the saved session for an agent. - - Display the session for `AGENT_URL`. - - Examples: - - handler session show http://localhost:8000 - """ + """Display session state for an agent.""" mode = "raw" if ctx.obj.get("raw") else "text" with get_output_context(mode) as out: @@ -1163,14 +1048,7 @@ def session_show(ctx: click.Context, agent_url: str) -> None: @click.argument("agent_url") @click.pass_context def session_get(ctx: click.Context, agent_url: str) -> None: - """Get a specific session value. - - Retrieve the session for `AGENT_URL`. - - Examples: - - handler session get http://localhost:8000 - """ + """Get session state for an agent (alias for show).""" ctx.invoke(session_show, agent_url=agent_url) @@ -1181,16 +1059,7 @@ def session_get(ctx: click.Context, agent_url: str) -> None: def session_clear( ctx: click.Context, agent_url: Optional[str], clear_all: bool ) -> None: - """Clear saved sessions. - - Clear the session for `AGENT_URL`, or use `--all` to clear all sessions. - - Examples: - - handler session clear http://localhost:8000 - - handler session clear --all - """ + """Clear saved session state.""" mode = "raw" if ctx.obj.get("raw") else "text" with get_output_context(mode) as out: @@ -1211,23 +1080,13 @@ def session_clear( @cli.command() def version() -> None: - """Display the current version. - - Examples: - - handler version - """ + """Display the current version.""" click.echo(__version__) @cli.command() def tui() -> None: - """Launch the interactive TUI. - - Examples: - - handler tui - """ + """Launch the interactive terminal interface.""" log.info("Launching TUI") logging.getLogger().handlers = [] app = HandlerTUI() From 22f80c13c7860046ddcc16adb577a4f364e27a2b Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Tue, 9 Dec 2025 20:16:34 -0500 Subject: [PATCH 17/23] refact: remove raw json rpc attempted methods --- src/a2a_handler/cli.py | 102 +------------------------------------ src/a2a_handler/service.py | 74 --------------------------- 2 files changed, 2 insertions(+), 174 deletions(-) diff --git a/src/a2a_handler/cli.py b/src/a2a_handler/cli.py index 56bf1a7..9834fbb 100644 --- a/src/a2a_handler/cli.py +++ b/src/a2a_handler/cli.py @@ -3,7 +3,7 @@ Command structure based on A2A protocol method mapping: - message send/stream: Send messages to agents - task get/cancel/resubscribe: Manage tasks -- task notification set/get/list/delete: Push notification configs +- task notification set/get: Push notification config - card get/validate: Agent card operations - server agent/push: Run local servers - session list/show/get/clear: Manage saved sessions @@ -123,7 +123,7 @@ {"name": "Push Notifications", "commands": ["notification"]}, ], "handler task notification": [ - {"name": "Notification Commands", "commands": ["set", "get", "list", "delete"]}, + {"name": "Notification Commands", "commands": ["set", "get"]}, ], "handler card": [ {"name": "Card Commands", "commands": ["get", "validate"]}, @@ -707,104 +707,6 @@ async def do_get() -> None: asyncio.run(do_get()) -@task_notification.command("list") -@click.argument("agent_url") -@click.argument("task_id") -@click.option( - "--output", - "-o", - type=click.Choice(["json", "text"]), - default="text", - help="Output format", -) -@click.pass_context -def notification_list( - ctx: click.Context, - agent_url: str, - task_id: str, - output: str, -) -> None: - """List all push notification configs for a task.""" - log.info("Listing push configs for task %s at %s", task_id, agent_url) - mode = get_mode(ctx, output) - - async def do_list() -> None: - with get_output_context(mode) as out: - try: - async with build_http_client() as http_client: - service = A2AService(http_client, agent_url) - configs = await service.list_push_configs(task_id) - - if output == "json": - out.out_json([c.model_dump() for c in configs]) - else: - if not configs: - out.out_dim("No push notification configs") - return - - out.out_header(f"Push Notification Configs ({len(configs)})") - for config in configs: - out.out_blank() - out.out_field("Task ID", config.task_id) - if config.push_notification_config: - pnc = config.push_notification_config - out.out_field("URL", pnc.url) - if pnc.id: - out.out_field("Config ID", pnc.id) - - except Exception as e: - _handle_client_error(e, agent_url, out) - raise click.Abort() - - asyncio.run(do_list()) - - -@task_notification.command("delete") -@click.argument("agent_url") -@click.argument("task_id") -@click.argument("config_id") -@click.option( - "--output", - "-o", - type=click.Choice(["json", "text"]), - default="text", - help="Output format", -) -@click.pass_context -def notification_delete( - ctx: click.Context, - agent_url: str, - task_id: str, - config_id: str, - output: str, -) -> None: - """Delete a push notification config.""" - log.info("Deleting push config %s for task %s at %s", config_id, task_id, agent_url) - mode = get_mode(ctx, output) - - async def do_delete() -> None: - with get_output_context(mode) as out: - try: - async with build_http_client() as http_client: - service = A2AService(http_client, agent_url) - - if mode != "json": - out.out_dim(f"Deleting config {config_id}...") - - await service.delete_push_config(task_id, config_id) - - if output == "json": - out.out_json({"deleted": True, "config_id": config_id}) - else: - out.out_success(f"Deleted config {config_id}") - - except Exception as e: - _handle_client_error(e, agent_url, out) - raise click.Abort() - - asyncio.run(do_delete()) - - # ============================================================================ # Card Commands # ============================================================================ diff --git a/src/a2a_handler/service.py b/src/a2a_handler/service.py index 3cafb6e..f643d8a 100644 --- a/src/a2a_handler/service.py +++ b/src/a2a_handler/service.py @@ -541,77 +541,3 @@ async def get_push_config( logger.info("Getting push config %s for task %s", config_id, task_id) return await client.get_task_callback(query_params) - - async def list_push_configs( - self, - task_id: str, - ) -> list[TaskPushNotificationConfig]: - """List all push notification configurations for a task. - - Args: - task_id: ID of the task - - Returns: - List of push notification configurations - - Note: - This method uses raw JSON-RPC as the SDK doesn't expose this yet. - """ - await self.get_card() - logger.info("Listing push configs for task %s", task_id) - - request_payload = { - "jsonrpc": "2.0", - "method": "tasks/pushNotificationConfig/list", - "params": {"id": task_id}, - "id": str(uuid.uuid4()), - } - - response = await self.http_client.post( - self.agent_url, - json=request_payload, - headers={"Content-Type": "application/json"}, - ) - response.raise_for_status() - result = response.json() - - if "error" in result: - raise RuntimeError(result["error"].get("message", "Unknown error")) - - configs_data = result.get("result", []) - return [TaskPushNotificationConfig.model_validate(c) for c in configs_data] - - async def delete_push_config( - self, - task_id: str, - config_id: str, - ) -> None: - """Delete a push notification configuration for a task. - - Args: - task_id: ID of the task - config_id: ID of the push config to delete - - Note: - This method uses raw JSON-RPC as the SDK doesn't expose this yet. - """ - await self.get_card() - logger.info("Deleting push config %s for task %s", config_id, task_id) - - request_payload = { - "jsonrpc": "2.0", - "method": "tasks/pushNotificationConfig/delete", - "params": {"id": task_id, "push_notification_config_id": config_id}, - "id": str(uuid.uuid4()), - } - - response = await self.http_client.post( - self.agent_url, - json=request_payload, - headers={"Content-Type": "application/json"}, - ) - response.raise_for_status() - result = response.json() - - if "error" in result: - raise RuntimeError(result["error"].get("message", "Unknown error")) From 3eb89df67ba3f5e8e30bf096e62e0f76f0eec2b4 Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Tue, 9 Dec 2025 20:31:22 -0500 Subject: [PATCH 18/23] fix: remove notification get command for now as its redundant --- src/a2a_handler/cli.py | 54 ++------------------------------------ src/a2a_handler/service.py | 25 ------------------ 2 files changed, 2 insertions(+), 77 deletions(-) diff --git a/src/a2a_handler/cli.py b/src/a2a_handler/cli.py index 9834fbb..5a346a0 100644 --- a/src/a2a_handler/cli.py +++ b/src/a2a_handler/cli.py @@ -3,7 +3,7 @@ Command structure based on A2A protocol method mapping: - message send/stream: Send messages to agents - task get/cancel/resubscribe: Manage tasks -- task notification set/get: Push notification config +- task notification set: Push notification config - card get/validate: Agent card operations - server agent/push: Run local servers - session list/show/get/clear: Manage saved sessions @@ -123,7 +123,7 @@ {"name": "Push Notifications", "commands": ["notification"]}, ], "handler task notification": [ - {"name": "Notification Commands", "commands": ["set", "get"]}, + {"name": "Notification Commands", "commands": ["set"]}, ], "handler card": [ {"name": "Card Commands", "commands": ["get", "validate"]}, @@ -657,56 +657,6 @@ async def do_set() -> None: asyncio.run(do_set()) -@task_notification.command("get") -@click.argument("agent_url") -@click.argument("task_id") -@click.argument("config_id") -@click.option( - "--output", - "-o", - type=click.Choice(["json", "text"]), - default="text", - help="Output format", -) -@click.pass_context -def notification_get( - ctx: click.Context, - agent_url: str, - task_id: str, - config_id: str, - output: str, -) -> None: - """Retrieve a push notification config by ID.""" - log.info("Getting push config %s for task %s at %s", config_id, task_id, agent_url) - mode = get_mode(ctx, output) - - async def do_get() -> None: - with get_output_context(mode) as out: - try: - async with build_http_client() as http_client: - service = A2AService(http_client, agent_url) - config = await service.get_push_config(task_id, config_id) - - if output == "json": - out.out_json(config.model_dump()) - else: - out.out_header("Push Notification Config") - out.out_field("Task ID", config.task_id) - if config.push_notification_config: - pnc = config.push_notification_config - out.out_field("URL", pnc.url) - if pnc.token: - out.out_field("Token", f"{pnc.token[:20]}...") - if pnc.id: - out.out_field("Config ID", pnc.id) - - except Exception as e: - _handle_client_error(e, agent_url, out) - raise click.Abort() - - asyncio.run(do_get()) - - # ============================================================================ # Card Commands # ============================================================================ diff --git a/src/a2a_handler/service.py b/src/a2a_handler/service.py index f643d8a..5ec6e26 100644 --- a/src/a2a_handler/service.py +++ b/src/a2a_handler/service.py @@ -11,7 +11,6 @@ from a2a.client import A2ACardResolver, Client, ClientConfig, ClientFactory from a2a.types import ( AgentCard, - GetTaskPushNotificationConfigParams, Message, Part, PushNotificationConfig, @@ -517,27 +516,3 @@ async def set_push_config( logger.info("Setting push config for task %s: %s", task_id, webhook_url) return await client.set_task_callback(push_config) - - async def get_push_config( - self, - task_id: str, - config_id: str, - ) -> TaskPushNotificationConfig: - """Get push notification configuration for a task. - - Args: - task_id: ID of the task - config_id: ID of the push config - - Returns: - The push notification configuration - """ - client = await self._get_or_create_client() - - query_params = GetTaskPushNotificationConfigParams( - id=task_id, - push_notification_config_id=config_id, - ) - logger.info("Getting push config %s for task %s", config_id, task_id) - - return await client.get_task_callback(query_params) From fe6605b6eed999fb37f1c564ef3a842a2ced119e Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Tue, 9 Dec 2025 20:50:51 -0500 Subject: [PATCH 19/23] refactor: remove session get command --- src/a2a_handler/cli.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/a2a_handler/cli.py b/src/a2a_handler/cli.py index 5a346a0..6b444f8 100644 --- a/src/a2a_handler/cli.py +++ b/src/a2a_handler/cli.py @@ -6,7 +6,7 @@ - task notification set: Push notification config - card get/validate: Agent card operations - server agent/push: Run local servers -- session list/show/get/clear: Manage saved sessions +- session list/show/clear: Manage saved sessions """ import asyncio @@ -132,7 +132,7 @@ {"name": "Server Commands", "commands": ["agent", "push"]}, ], "handler session": [ - {"name": "Session Commands", "commands": ["list", "show", "get", "clear"]}, + {"name": "Session Commands", "commands": ["list", "show", "clear"]}, ], } @@ -896,14 +896,6 @@ def session_show(ctx: click.Context, agent_url: str) -> None: out.out_field("Task ID", s.task_id or "none", dim_value=not s.task_id) -@session.command("get") -@click.argument("agent_url") -@click.pass_context -def session_get(ctx: click.Context, agent_url: str) -> None: - """Get session state for an agent (alias for show).""" - ctx.invoke(session_show, agent_url=agent_url) - - @session.command("clear") @click.argument("agent_url", required=False) @click.option("--all", "-a", "clear_all", is_flag=True, help="Clear all sessions") From 0070dc10733ff7190b9a710f4d8736885e1eed53 Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Tue, 9 Dec 2025 21:20:38 -0500 Subject: [PATCH 20/23] refactor: rename OutputContext to Output and update API --- src/a2a_handler/cli.py | 216 +++++++++++++---------------- src/a2a_handler/common/__init__.py | 4 +- src/a2a_handler/common/output.py | 38 ++--- 3 files changed, 121 insertions(+), 137 deletions(-) diff --git a/src/a2a_handler/cli.py b/src/a2a_handler/cli.py index 6b444f8..8a775d2 100644 --- a/src/a2a_handler/cli.py +++ b/src/a2a_handler/cli.py @@ -11,7 +11,7 @@ import asyncio import logging -from typing import Optional +from typing import Any, Optional logging.getLogger().setLevel(logging.WARNING) @@ -22,6 +22,7 @@ A2AClientHTTPError, A2AClientTimeoutError, ) +from a2a.types import AgentCard from a2a_handler import __version__ from a2a_handler.common import ( @@ -31,6 +32,7 @@ get_output_context, setup_logging, ) +from a2a_handler.common.output import Output from a2a_handler.server import run_server from a2a_handler.service import A2AService, SendResult, TaskResult from a2a_handler.session import ( @@ -146,43 +148,41 @@ def build_http_client(timeout: int = TIMEOUT) -> httpx.AsyncClient: return httpx.AsyncClient(timeout=timeout) -def _handle_client_error(e: Exception, agent_url: str, ctx: object) -> None: +def _handle_client_error(e: Exception, agent_url: str, context: object) -> None: """Handle A2A client errors with appropriate messages.""" - from a2a_handler.common.output import OutputContext + output = context if isinstance(context, Output) else None - out = ctx if isinstance(ctx, OutputContext) else None - - msg = "" + message = "" if isinstance(e, A2AClientTimeoutError): log.error("Request to %s timed out", agent_url) - msg = "Request timed out" + message = "Request timed out" elif isinstance(e, A2AClientHTTPError): log.error("A2A client error: %s", e) - msg = ( + message = ( f"Connection failed: Is the server running at {agent_url}?" if "connection" in str(e).lower() else str(e) ) elif isinstance(e, A2AClientError): log.error("A2A client error: %s", e) - msg = str(e) + message = str(e) elif isinstance(e, httpx.ConnectError): log.error("Connection refused to %s", agent_url) - msg = f"Connection refused: Is the server running at {agent_url}?" + message = f"Connection refused: Is the server running at {agent_url}?" elif isinstance(e, httpx.TimeoutException): log.error("Request to %s timed out", agent_url) - msg = "Request timed out" + message = "Request timed out" elif isinstance(e, httpx.HTTPStatusError): log.error("HTTP error %d from %s", e.response.status_code, agent_url) - msg = f"HTTP {e.response.status_code} - {e.response.text}" + message = f"HTTP {e.response.status_code} - {e.response.text}" else: log.exception("Failed request to %s", agent_url) - msg = str(e) + message = str(e) - if out: - out.out_error(msg) + if output: + output.error(message) else: - click.echo(f"Error: {msg}", err=True) + click.echo(f"Error: {message}", err=True) # ============================================================================ @@ -283,7 +283,7 @@ async def do_send() -> None: ) if mode != "json": - out.out_dim(f"Sending to {agent_url}...") + out.dim(f"Sending to {agent_url}...") if stream: await _stream_message( @@ -352,12 +352,10 @@ async def _stream_message( task_id: Optional[str], agent_url: str, out: object, - output: str, + output_format: str, ) -> None: """Stream a message and handle events.""" - from a2a_handler.common.output import OutputContext - - out_ctx = out if isinstance(out, OutputContext) else None + output = out if isinstance(out, Output) else None collected_text: list[str] = [] last_context_id: str | None = None last_task_id: str | None = None @@ -368,7 +366,7 @@ async def _stream_message( last_task_id = event.task_id or last_task_id last_state = event.state or last_state - if output == "json": + if output_format == "json": event_data = { "type": event.event_type, "context_id": event.context_id, @@ -376,51 +374,49 @@ async def _stream_message( "state": event.state.value if event.state else None, "text": event.text, } - if out_ctx: - out_ctx.out_json(event_data) + if output: + output.json(event_data) else: if event.text and event.text not in collected_text: - if out_ctx: - out_ctx.out_line(event.text) + if output: + output.line(event.text) collected_text.append(event.text) update_session(agent_url, last_context_id, last_task_id) - if output != "json" and out_ctx: - out_ctx.out_blank() + if output_format != "json" and output: + output.blank() if last_context_id: - out_ctx.out_field("Context ID", last_context_id, dim_value=True) + output.field("Context ID", last_context_id, dim_value=True) if last_task_id: - out_ctx.out_field("Task ID", last_task_id, dim_value=True) + output.field("Task ID", last_task_id, dim_value=True) if last_state: - out_ctx.out_state("State", last_state.value) + output.state("State", last_state.value) -def _format_send_result(result: SendResult, out: object, output: str) -> None: +def _format_send_result(result: SendResult, out: object, output_format: str) -> None: """Format and display a send result.""" - from a2a_handler.common.output import OutputContext - - out_ctx = out if isinstance(out, OutputContext) else None - if not out_ctx: + output = out if isinstance(out, Output) else None + if not output: return - if output == "json": - out_ctx.out_json(result.raw) + if output_format == "json": + output.json(result.raw) return - out_ctx.out_blank() + output.blank() if result.context_id: - out_ctx.out_field("Context ID", result.context_id, dim_value=True) + output.field("Context ID", result.context_id, dim_value=True) if result.task_id: - out_ctx.out_field("Task ID", result.task_id, dim_value=True) + output.field("Task ID", result.task_id, dim_value=True) if result.state: - out_ctx.out_state("State", result.state.value) + output.state("State", result.state.value) - out_ctx.out_blank() + output.blank() if result.text: - out_ctx.out_markdown(result.text) + output.markdown(result.text) else: - out_ctx.out_dim("No text content in response") + output.dim("No text content in response") # ============================================================================ @@ -496,13 +492,13 @@ async def do_cancel() -> None: service = A2AService(http_client, agent_url) if mode != "json": - out.out_dim(f"Canceling task {task_id}...") + out.dim(f"Canceling task {task_id}...") result = await service.cancel_task(task_id) _format_task_result(result, out, output) if mode != "json": - out.out_success("Task canceled") + out.success("Task canceled") except Exception as e: _handle_client_error(e, agent_url, out) @@ -536,11 +532,11 @@ async def do_resubscribe() -> None: service = A2AService(http_client, agent_url) if mode != "json": - out.out_dim(f"Resubscribing to task {task_id}...") + out.dim(f"Resubscribing to task {task_id}...") async for event in service.resubscribe(task_id): if output == "json": - out.out_json( + out.json( { "type": event.event_type, "context_id": event.context_id, @@ -551,12 +547,12 @@ async def do_resubscribe() -> None: ) else: if event.event_type == "status": - out.out_state( + out.state( "Status", event.state.value if event.state else "unknown", ) elif event.text: - out.out_line(event.text) + out.line(event.text) except Exception as e: _handle_client_error(e, agent_url, out) @@ -565,27 +561,25 @@ async def do_resubscribe() -> None: asyncio.run(do_resubscribe()) -def _format_task_result(result: TaskResult, out: object, output: str) -> None: +def _format_task_result(result: TaskResult, out: object, output_format: str) -> None: """Format and display a task result.""" - from a2a_handler.common.output import OutputContext - - out_ctx = out if isinstance(out, OutputContext) else None - if not out_ctx: + output = out if isinstance(out, Output) else None + if not output: return - if output == "json": - out_ctx.out_json(result.raw) + if output_format == "json": + output.json(result.raw) return - out_ctx.out_blank() - out_ctx.out_field("Task ID", result.task_id, dim_value=True) - out_ctx.out_state("State", result.state.value) + output.blank() + output.field("Task ID", result.task_id, dim_value=True) + output.state("State", result.state.value) if result.context_id: - out_ctx.out_field("Context ID", result.context_id, dim_value=True) + output.field("Context ID", result.context_id, dim_value=True) if result.text: - out_ctx.out_blank() - out_ctx.out_markdown(result.text) + output.blank() + output.markdown(result.text) # ============================================================================ @@ -631,24 +625,22 @@ async def do_set() -> None: service = A2AService(http_client, agent_url) if mode != "json": - out.out_dim( - f"Setting notification config for task {task_id}..." - ) + out.dim(f"Setting notification config for task {task_id}...") config = await service.set_push_config(task_id, url, token) if output == "json": - out.out_json(config.model_dump()) + out.json(config.model_dump()) else: - out.out_success("Push notification config set") - out.out_field("Task ID", config.task_id) + out.success("Push notification config set") + out.field("Task ID", config.task_id) if config.push_notification_config: pnc = config.push_notification_config - out.out_field("URL", pnc.url) + out.field("URL", pnc.url) if pnc.token: - out.out_field("Token", f"{pnc.token[:20]}...") + out.field("Token", f"{pnc.token[:20]}...") if pnc.id: - out.out_field("Config ID", pnc.id) + out.field("Config ID", pnc.id) except Exception as e: _handle_client_error(e, agent_url, out) @@ -697,7 +689,7 @@ async def do_get() -> None: log.info("Retrieved card for agent: %s", card_data.name) if output == "json": - out.out_json(card_data.model_dump()) + out.json(card_data.model_dump()) else: _format_agent_card(card_data, out) @@ -710,14 +702,8 @@ async def do_get() -> None: def _format_agent_card(card_data: object, out: object) -> None: """Format and display an agent card.""" - from typing import Any - - from a2a.types import AgentCard - - from a2a_handler.common.output import OutputContext - - out_ctx = out if isinstance(out, OutputContext) else None - if not out_ctx: + output = out if isinstance(out, Output) else None + if not output: return card_dict: dict[str, Any] @@ -728,11 +714,11 @@ def _format_agent_card(card_data: object, out: object) -> None: name = card_dict.pop("name", "Unknown Agent") description = card_dict.pop("description", "") - out_ctx.out_header(name) + output.header(name) if description: - out_ctx.out_line(description) + output.line(description) - out_ctx.out_blank() + output.blank() for key, value in card_dict.items(): if key.startswith("_"): continue @@ -740,10 +726,10 @@ def _format_agent_card(card_data: object, out: object) -> None: if formatted: field_name = format_field_name(key) if "\n" in formatted: - out_ctx.out_line(f"{field_name}:") - out_ctx.out_line(formatted) + output.line(f"{field_name}:") + output.line(formatted) else: - out_ctx.out_field(field_name, formatted) + output.field(field_name, formatted) @card.command("validate") @@ -779,17 +765,15 @@ async def do_validate() -> None: def _format_validation_result( - result: ValidationResult, out: object, output: str + result: ValidationResult, out: object, output_format: str ) -> None: """Format and display validation result.""" - from a2a_handler.common.output import OutputContext - - out_ctx = out if isinstance(out, OutputContext) else None - if not out_ctx: + output = out if isinstance(out, Output) else None + if not output: return - if output == "json": - out_ctx.out_json( + if output_format == "json": + output.json( { "valid": result.valid, "source": result.source, @@ -805,17 +789,17 @@ def _format_validation_result( return if result.valid: - out_ctx.out_success("Valid Agent Card") - out_ctx.out_field("Agent", result.agent_name) - out_ctx.out_field("Protocol Version", result.protocol_version) - out_ctx.out_field("Source", result.source) + output.success("Valid Agent Card") + output.field("Agent", result.agent_name) + output.field("Protocol Version", result.protocol_version) + output.field("Source", result.source) else: - out_ctx.out_error("Invalid Agent Card") - out_ctx.out_field("Source", result.source) - out_ctx.out_blank() - out_ctx.out_line(f"Errors ({len(result.issues)}):") + output.error("Invalid Agent Card") + output.field("Source", result.source) + output.blank() + output.line(f"Errors ({len(result.issues)}):") for issue in result.issues: - out_ctx.out_list_item(f"{issue.field_name}: {issue.message}", bullet="✗") + output.list_item(f"{issue.field_name}: {issue.message}", bullet="✗") # ============================================================================ @@ -869,17 +853,17 @@ def session_list(ctx: click.Context) -> None: sessions = store.list_all() if not sessions: - out.out_dim("No saved sessions") + out.dim("No saved sessions") return - out.out_header(f"Saved Sessions ({len(sessions)})") + out.header(f"Saved Sessions ({len(sessions)})") for s in sessions: - out.out_blank() - out.out_subheader(s.agent_url) + out.blank() + out.subheader(s.agent_url) if s.context_id: - out.out_field("Context ID", s.context_id, dim_value=True) + out.field("Context ID", s.context_id, dim_value=True) if s.task_id: - out.out_field("Task ID", s.task_id, dim_value=True) + out.field("Task ID", s.task_id, dim_value=True) @session.command("show") @@ -891,9 +875,9 @@ def session_show(ctx: click.Context, agent_url: str) -> None: with get_output_context(mode) as out: s = get_session(agent_url) - out.out_header(f"Session for {agent_url}") - out.out_field("Context ID", s.context_id or "none", dim_value=not s.context_id) - out.out_field("Task ID", s.task_id or "none", dim_value=not s.task_id) + out.header(f"Session for {agent_url}") + out.field("Context ID", s.context_id or "none", dim_value=not s.context_id) + out.field("Task ID", s.task_id or "none", dim_value=not s.task_id) @session.command("clear") @@ -909,12 +893,12 @@ def session_clear( with get_output_context(mode) as out: if clear_all: clear_session() - out.out_success("Cleared all sessions") + out.success("Cleared all sessions") elif agent_url: clear_session(agent_url) - out.out_success(f"Cleared session for {agent_url}") + out.success(f"Cleared session for {agent_url}") else: - out.out_warning("Provide AGENT_URL or use --all to clear sessions") + out.warning("Provide AGENT_URL or use --all to clear sessions") # ============================================================================ diff --git a/src/a2a_handler/common/__init__.py b/src/a2a_handler/common/__init__.py index 3838ee6..72a1fd0 100644 --- a/src/a2a_handler/common/__init__.py +++ b/src/a2a_handler/common/__init__.py @@ -12,7 +12,7 @@ setup_logging, ) from .output import ( - OutputContext, + Output, OutputMode, get_output_context, ) @@ -20,7 +20,7 @@ __all__ = [ "HANDLER_THEME", "LogLevel", - "OutputContext", + "Output", "OutputMode", "console", "format_field_name", diff --git a/src/a2a_handler/common/output.py b/src/a2a_handler/common/output.py index 6a4a613..2a63179 100644 --- a/src/a2a_handler/common/output.py +++ b/src/a2a_handler/common/output.py @@ -33,7 +33,7 @@ def _strip_markup(text: str) -> str: return re.sub(r"\[/?[^\]]+\]", "", text) -class OutputContext: +class Output: """Manages output mode and styling. Provides a unified interface for outputting text, fields, JSON, and @@ -58,11 +58,11 @@ def _print(self, text: str, style: str | None = None) -> None: else: console.print(text, markup=self.mode == OutputMode.TEXT) - def out_line(self, text: str, style: str | None = None) -> None: + def line(self, text: str, style: str | None = None) -> None: """Print a line of text with optional style.""" self._print(text, style) - def out_field( + def field( self, name: str, value: Any, @@ -82,28 +82,28 @@ def out_field( else: self._raw_console.print(f"{name}: {_strip_markup(value_str)}") - def out_header(self, text: str) -> None: + def header(self, text: str) -> None: """Print a section header.""" if self.mode == OutputMode.TEXT: console.print(f"\n[bold]{text}[/bold]") else: self._raw_console.print(f"\n{text}") - def out_subheader(self, text: str) -> None: + def subheader(self, text: str) -> None: """Print a subheader (less prominent than header).""" if self.mode == OutputMode.TEXT: console.print(f"[bold cyan]{text}[/bold cyan]") else: self._raw_console.print(text) - def out_blank(self) -> None: + def blank(self) -> None: """Print a blank line.""" if self.mode == OutputMode.TEXT: console.print() else: self._raw_console.print() - def out_state(self, name: str, state: str) -> None: + def state(self, name: str, state: str) -> None: """Print a state field with appropriate coloring.""" if self.mode == OutputMode.TEXT: lower = state.lower() @@ -121,35 +121,35 @@ def out_state(self, name: str, state: str) -> None: else: self._raw_console.print(f"{name}: {state}") - def out_success(self, text: str) -> None: + def success(self, text: str) -> None: """Print a success message.""" self._print(text, "green") - def out_error(self, text: str) -> None: + def error(self, text: str) -> None: """Print an error message.""" self._print(text, "red bold") - def out_warning(self, text: str) -> None: + def warning(self, text: str) -> None: """Print a warning message.""" self._print(text, "yellow") - def out_dim(self, text: str) -> None: + def dim(self, text: str) -> None: """Print dimmed/muted text.""" self._print(text, "dim") - def out_json(self, data: Any) -> None: + def json(self, data: Any) -> None: """Print JSON data.""" json_str = json_module.dumps(data, indent=2, default=str) self._raw_console.print(json_str) - def out_markdown(self, text: str) -> None: + def markdown(self, text: str) -> None: """Print markdown content.""" if self.mode == OutputMode.TEXT: console.print(Markdown(text)) else: self._raw_console.print(text) - def out_list_item(self, text: str, bullet: str = "•") -> None: + def list_item(self, text: str, bullet: str = "•") -> None: """Print a list item with bullet.""" if self.mode == OutputMode.TEXT: console.print(f" [dim]{bullet}[/dim] {text}") @@ -157,21 +157,21 @@ def out_list_item(self, text: str, bullet: str = "•") -> None: self._raw_console.print(f" {bullet} {_strip_markup(text)}") -_current_context: OutputContext | None = None +_current_context: Output | None = None @contextmanager def get_output_context( mode: OutputMode | str, -) -> Generator[OutputContext, None, None]: +) -> Generator[Output, None, None]: global _current_context if isinstance(mode, str): mode = OutputMode(mode) - ctx = OutputContext(mode) - _current_context = ctx + context = Output(mode) + _current_context = context try: - yield ctx + yield context finally: _current_context = None From 9879f4bc1fbf5bc49e03f84b922fbad1b09e46ce Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Tue, 9 Dec 2025 21:25:31 -0500 Subject: [PATCH 21/23] refactor: standardize CLI output handling --- src/a2a_handler/cli.py | 157 ++++++++++++++++++----------------------- 1 file changed, 70 insertions(+), 87 deletions(-) diff --git a/src/a2a_handler/cli.py b/src/a2a_handler/cli.py index 8a775d2..99819a8 100644 --- a/src/a2a_handler/cli.py +++ b/src/a2a_handler/cli.py @@ -271,7 +271,7 @@ def message_send( mode = get_mode(ctx, output) async def do_send() -> None: - with get_output_context(mode) as out: + with get_output_context(mode) as output: try: async with build_http_client() as http_client: service = A2AService( @@ -283,19 +283,19 @@ async def do_send() -> None: ) if mode != "json": - out.dim(f"Sending to {agent_url}...") + output.dim(f"Sending to {agent_url}...") if stream: await _stream_message( - service, text, context_id, task_id, agent_url, out, output + service, text, context_id, task_id, agent_url, output ) else: result = await service.send(text, context_id, task_id) update_session(agent_url, result.context_id, result.task_id) - _format_send_result(result, out, output) + _format_send_result(result, output) except Exception as e: - _handle_client_error(e, agent_url, out) + _handle_client_error(e, agent_url, output) raise click.Abort() asyncio.run(do_send()) @@ -351,22 +351,22 @@ async def _stream_message( context_id: Optional[str], task_id: Optional[str], agent_url: str, - out: object, - output_format: str, + output: Output, ) -> None: """Stream a message and handle events.""" - output = out if isinstance(out, Output) else None collected_text: list[str] = [] last_context_id: str | None = None last_task_id: str | None = None last_state = None + is_json = output.mode.value == "json" + async for event in service.stream(text, context_id, task_id): last_context_id = event.context_id or last_context_id last_task_id = event.task_id or last_task_id last_state = event.state or last_state - if output_format == "json": + if is_json: event_data = { "type": event.event_type, "context_id": event.context_id, @@ -374,17 +374,15 @@ async def _stream_message( "state": event.state.value if event.state else None, "text": event.text, } - if output: - output.json(event_data) + output.json(event_data) else: if event.text and event.text not in collected_text: - if output: - output.line(event.text) + output.line(event.text) collected_text.append(event.text) update_session(agent_url, last_context_id, last_task_id) - if output_format != "json" and output: + if not is_json: output.blank() if last_context_id: output.field("Context ID", last_context_id, dim_value=True) @@ -394,13 +392,9 @@ async def _stream_message( output.state("State", last_state.value) -def _format_send_result(result: SendResult, out: object, output_format: str) -> None: +def _format_send_result(result: SendResult, output: Output) -> None: """Format and display a send result.""" - output = out if isinstance(out, Output) else None - if not output: - return - - if output_format == "json": + if output.mode.value == "json": output.json(result.raw) return @@ -456,14 +450,14 @@ def task_get( mode = get_mode(ctx, output) async def do_get() -> None: - with get_output_context(mode) as out: + with get_output_context(mode) as output: try: async with build_http_client() as http_client: service = A2AService(http_client, agent_url) result = await service.get_task(task_id, history_length) - _format_task_result(result, out, output) + _format_task_result(result, output) except Exception as e: - _handle_client_error(e, agent_url, out) + _handle_client_error(e, agent_url, output) raise click.Abort() asyncio.run(do_get()) @@ -486,22 +480,22 @@ def task_cancel(ctx: click.Context, agent_url: str, task_id: str, output: str) - mode = get_mode(ctx, output) async def do_cancel() -> None: - with get_output_context(mode) as out: + with get_output_context(mode) as output: try: async with build_http_client() as http_client: service = A2AService(http_client, agent_url) if mode != "json": - out.dim(f"Canceling task {task_id}...") + output.dim(f"Canceling task {task_id}...") result = await service.cancel_task(task_id) - _format_task_result(result, out, output) + _format_task_result(result, output) if mode != "json": - out.success("Task canceled") + output.success("Task canceled") except Exception as e: - _handle_client_error(e, agent_url, out) + _handle_client_error(e, agent_url, output) raise click.Abort() asyncio.run(do_cancel()) @@ -526,17 +520,18 @@ def task_resubscribe( mode = get_mode(ctx, output) async def do_resubscribe() -> None: - with get_output_context(mode) as out: + with get_output_context(mode) as output: try: async with build_http_client() as http_client: service = A2AService(http_client, agent_url) + is_json = output.mode.value == "json" - if mode != "json": - out.dim(f"Resubscribing to task {task_id}...") + if not is_json: + output.dim(f"Resubscribing to task {task_id}...") async for event in service.resubscribe(task_id): - if output == "json": - out.json( + if is_json: + output.json( { "type": event.event_type, "context_id": event.context_id, @@ -547,27 +542,23 @@ async def do_resubscribe() -> None: ) else: if event.event_type == "status": - out.state( + output.state( "Status", event.state.value if event.state else "unknown", ) elif event.text: - out.line(event.text) + output.line(event.text) except Exception as e: - _handle_client_error(e, agent_url, out) + _handle_client_error(e, agent_url, output) raise click.Abort() asyncio.run(do_resubscribe()) -def _format_task_result(result: TaskResult, out: object, output_format: str) -> None: +def _format_task_result(result: TaskResult, output: Output) -> None: """Format and display a task result.""" - output = out if isinstance(out, Output) else None - if not output: - return - - if output_format == "json": + if output.mode.value == "json": output.json(result.raw) return @@ -619,31 +610,32 @@ def notification_set( mode = get_mode(ctx, output) async def do_set() -> None: - with get_output_context(mode) as out: + with get_output_context(mode) as output: try: async with build_http_client() as http_client: service = A2AService(http_client, agent_url) + is_json = output.mode.value == "json" - if mode != "json": - out.dim(f"Setting notification config for task {task_id}...") + if not is_json: + output.dim(f"Setting notification config for task {task_id}...") config = await service.set_push_config(task_id, url, token) - if output == "json": - out.json(config.model_dump()) + if is_json: + output.json(config.model_dump()) else: - out.success("Push notification config set") - out.field("Task ID", config.task_id) + output.success("Push notification config set") + output.field("Task ID", config.task_id) if config.push_notification_config: pnc = config.push_notification_config - out.field("URL", pnc.url) + output.field("URL", pnc.url) if pnc.token: - out.field("Token", f"{pnc.token[:20]}...") + output.field("Token", f"{pnc.token[:20]}...") if pnc.id: - out.field("Config ID", pnc.id) + output.field("Config ID", pnc.id) except Exception as e: - _handle_client_error(e, agent_url, out) + _handle_client_error(e, agent_url, output) raise click.Abort() asyncio.run(do_set()) @@ -681,30 +673,27 @@ def card_get( mode = get_mode(ctx, output) async def do_get() -> None: - with get_output_context(mode) as out: + with get_output_context(mode) as output: try: async with build_http_client() as http_client: service = A2AService(http_client, agent_url) card_data = await service.get_card() log.info("Retrieved card for agent: %s", card_data.name) - if output == "json": - out.json(card_data.model_dump()) + if output.mode.value == "json": + output.json(card_data.model_dump()) else: - _format_agent_card(card_data, out) + _format_agent_card(card_data, output) except Exception as e: - _handle_client_error(e, agent_url, out) + _handle_client_error(e, agent_url, output) raise click.Abort() asyncio.run(do_get()) -def _format_agent_card(card_data: object, out: object) -> None: +def _format_agent_card(card_data: object, output: Output) -> None: """Format and display an agent card.""" - output = out if isinstance(out, Output) else None - if not output: - return card_dict: dict[str, Any] if isinstance(card_data, AgentCard): @@ -749,14 +738,14 @@ def card_validate(ctx: click.Context, source: str, output: str) -> None: mode = get_mode(ctx, output) async def do_validate() -> None: - with get_output_context(mode) as out: + with get_output_context(mode) as output: if is_url: async with build_http_client() as http_client: result = await validate_agent_card_from_url(source, http_client) else: result = validate_agent_card_from_file(source) - _format_validation_result(result, out, output) + _format_validation_result(result, output) if not result.valid: raise SystemExit(1) @@ -764,15 +753,9 @@ async def do_validate() -> None: asyncio.run(do_validate()) -def _format_validation_result( - result: ValidationResult, out: object, output_format: str -) -> None: +def _format_validation_result(result: ValidationResult, output: Output) -> None: """Format and display validation result.""" - output = out if isinstance(out, Output) else None - if not output: - return - - if output_format == "json": + if output.mode.value == "json": output.json( { "valid": result.valid, @@ -848,22 +831,22 @@ def session_list(ctx: click.Context) -> None: """List all saved sessions.""" mode = "raw" if ctx.obj.get("raw") else "text" - with get_output_context(mode) as out: + with get_output_context(mode) as output: store = get_session_store() sessions = store.list_all() if not sessions: - out.dim("No saved sessions") + output.dim("No saved sessions") return - out.header(f"Saved Sessions ({len(sessions)})") + output.header(f"Saved Sessions ({len(sessions)})") for s in sessions: - out.blank() - out.subheader(s.agent_url) + output.blank() + output.subheader(s.agent_url) if s.context_id: - out.field("Context ID", s.context_id, dim_value=True) + output.field("Context ID", s.context_id, dim_value=True) if s.task_id: - out.field("Task ID", s.task_id, dim_value=True) + output.field("Task ID", s.task_id, dim_value=True) @session.command("show") @@ -873,11 +856,11 @@ def session_show(ctx: click.Context, agent_url: str) -> None: """Display session state for an agent.""" mode = "raw" if ctx.obj.get("raw") else "text" - with get_output_context(mode) as out: + with get_output_context(mode) as output: s = get_session(agent_url) - out.header(f"Session for {agent_url}") - out.field("Context ID", s.context_id or "none", dim_value=not s.context_id) - out.field("Task ID", s.task_id or "none", dim_value=not s.task_id) + output.header(f"Session for {agent_url}") + output.field("Context ID", s.context_id or "none", dim_value=not s.context_id) + output.field("Task ID", s.task_id or "none", dim_value=not s.task_id) @session.command("clear") @@ -890,15 +873,15 @@ def session_clear( """Clear saved session state.""" mode = "raw" if ctx.obj.get("raw") else "text" - with get_output_context(mode) as out: + with get_output_context(mode) as output: if clear_all: clear_session() - out.success("Cleared all sessions") + output.success("Cleared all sessions") elif agent_url: clear_session(agent_url) - out.success(f"Cleared session for {agent_url}") + output.success(f"Cleared session for {agent_url}") else: - out.warning("Provide AGENT_URL or use --all to clear sessions") + output.warning("Provide AGENT_URL or use --all to clear sessions") # ============================================================================ From 359f04ac62620df7193ba26c2318d6c05c34eb13 Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Tue, 9 Dec 2025 21:33:25 -0500 Subject: [PATCH 22/23] chore: improve and standardize module docstrings --- src/a2a_handler/__init__.py | 5 ++++- src/a2a_handler/cli.py | 6 +++--- src/a2a_handler/common/__init__.py | 5 ++++- src/a2a_handler/common/formatting.py | 5 ++++- src/a2a_handler/common/logging.py | 5 ++++- src/a2a_handler/common/output.py | 5 ++++- src/a2a_handler/server.py | 5 ++++- src/a2a_handler/service.py | 2 +- src/a2a_handler/session.py | 4 ++-- src/a2a_handler/tui/__init__.py | 5 ++++- src/a2a_handler/tui/app.py | 5 ++++- src/a2a_handler/tui/components/__init__.py | 2 ++ src/a2a_handler/tui/components/card.py | 2 +- src/a2a_handler/tui/components/contact.py | 2 +- src/a2a_handler/tui/components/footer.py | 2 +- src/a2a_handler/tui/components/input.py | 2 +- src/a2a_handler/tui/components/messages.py | 2 +- src/a2a_handler/validation.py | 5 ++++- src/a2a_handler/webhook.py | 5 ++--- tests/__init__.py | 2 +- tests/test_formatting.py | 2 +- tests/test_service.py | 2 +- tests/test_session.py | 2 +- tests/test_tui.py | 2 ++ tests/test_validation.py | 2 +- tests/test_webhook.py | 2 +- 26 files changed, 59 insertions(+), 29 deletions(-) diff --git a/src/a2a_handler/__init__.py b/src/a2a_handler/__init__.py index 5c492a3..ce0ee7c 100644 --- a/src/a2a_handler/__init__.py +++ b/src/a2a_handler/__init__.py @@ -1,4 +1,7 @@ -"""Handler - A2A protocol client and TUI for agent interaction""" +"""Handler - A2A protocol client and TUI for agent interaction. + +Provides CLI and TUI interfaces for communicating with A2A protocol agents. +""" from importlib.metadata import version diff --git a/src/a2a_handler/cli.py b/src/a2a_handler/cli.py index 99819a8..b1cda55 100644 --- a/src/a2a_handler/cli.py +++ b/src/a2a_handler/cli.py @@ -1,9 +1,9 @@ -"""Handler CLI - A2A protocol client. +"""Command-line interface for the Handler A2A protocol client. -Command structure based on A2A protocol method mapping: +Provides commands for interacting with A2A agents: - message send/stream: Send messages to agents - task get/cancel/resubscribe: Manage tasks -- task notification set: Push notification config +- task notification set: Configure push notifications - card get/validate: Agent card operations - server agent/push: Run local servers - session list/show/clear: Manage saved sessions diff --git a/src/a2a_handler/common/__init__.py b/src/a2a_handler/common/__init__.py index 72a1fd0..503ad72 100644 --- a/src/a2a_handler/common/__init__.py +++ b/src/a2a_handler/common/__init__.py @@ -1,4 +1,7 @@ -"""Common utilities for Handler.""" +"""Common utilities for the Handler package. + +Provides logging, formatting, and output utilities shared across modules. +""" from .formatting import ( format_field_name, diff --git a/src/a2a_handler/common/formatting.py b/src/a2a_handler/common/formatting.py index 11ac563..cfa959c 100644 --- a/src/a2a_handler/common/formatting.py +++ b/src/a2a_handler/common/formatting.py @@ -1,4 +1,7 @@ -"""Formatting utilities for Handler.""" +"""Formatting utilities for displaying structured data. + +Provides functions for converting field names and values to human-readable format. +""" import re from typing import Any diff --git a/src/a2a_handler/common/logging.py b/src/a2a_handler/common/logging.py index 89d8ff4..9f9d6ea 100644 --- a/src/a2a_handler/common/logging.py +++ b/src/a2a_handler/common/logging.py @@ -1,4 +1,7 @@ -"""Unified Rich logging configuration for Handler packages.""" +"""Rich logging configuration for Handler. + +Provides themed console output and structured logging across all modules. +""" import logging from typing import Literal diff --git a/src/a2a_handler/common/output.py b/src/a2a_handler/common/output.py index 2a63179..4edabfa 100644 --- a/src/a2a_handler/common/output.py +++ b/src/a2a_handler/common/output.py @@ -1,4 +1,7 @@ -"""Output formatting system with mode-aware styling.""" +"""Output formatting system with mode-aware styling. + +Provides a unified output interface supporting raw, text, and JSON modes. +""" from __future__ import annotations diff --git a/src/a2a_handler/server.py b/src/a2a_handler/server.py index f404576..563a8be 100644 --- a/src/a2a_handler/server.py +++ b/src/a2a_handler/server.py @@ -1,4 +1,7 @@ -"""Handler A2A server agent with full push notification support.""" +"""A2A server agent with streaming and push notification support. + +Provides a local A2A-compatible agent server for testing and development. +""" import os from collections.abc import Awaitable, Callable diff --git a/src/a2a_handler/service.py b/src/a2a_handler/service.py index 5ec6e26..3ba0a4f 100644 --- a/src/a2a_handler/service.py +++ b/src/a2a_handler/service.py @@ -1,6 +1,6 @@ """A2A protocol service layer. -Provides a unified interface for A2A operations, reusable between CLI and TUI. +Provides a unified interface for A2A operations, shared between the CLI and TUI. """ import uuid diff --git a/src/a2a_handler/session.py b/src/a2a_handler/session.py index a8128a9..b8eaaa4 100644 --- a/src/a2a_handler/session.py +++ b/src/a2a_handler/session.py @@ -1,6 +1,6 @@ -"""Session state management for A2A CLI. +"""Session state management for the Handler CLI. -Provides persistence of context_id and task_id across CLI invocations. +Persists context_id and task_id across CLI invocations for conversation continuity. """ import json diff --git a/src/a2a_handler/tui/__init__.py b/src/a2a_handler/tui/__init__.py index 560dc7b..d16e563 100644 --- a/src/a2a_handler/tui/__init__.py +++ b/src/a2a_handler/tui/__init__.py @@ -1,4 +1,7 @@ -"""Handler TUI application.""" +"""Handler TUI application. + +Provides an interactive terminal interface for communicating with A2A agents. +""" from a2a_handler.tui.app import HandlerTUI, main diff --git a/src/a2a_handler/tui/app.py b/src/a2a_handler/tui/app.py index 7e1e5c2..98587cb 100644 --- a/src/a2a_handler/tui/app.py +++ b/src/a2a_handler/tui/app.py @@ -1,4 +1,7 @@ -"""Handler TUI application.""" +"""Main TUI application for Handler. + +Provides the Textual-based terminal interface for agent interaction. +""" import logging import uuid diff --git a/src/a2a_handler/tui/components/__init__.py b/src/a2a_handler/tui/components/__init__.py index 31884df..617ac3c 100644 --- a/src/a2a_handler/tui/components/__init__.py +++ b/src/a2a_handler/tui/components/__init__.py @@ -1,3 +1,5 @@ +"""TUI component widgets for the Handler application.""" + from .card import AgentCardPanel from .contact import ContactPanel from .footer import Footer diff --git a/src/a2a_handler/tui/components/card.py b/src/a2a_handler/tui/components/card.py index a60059f..eb3df37 100644 --- a/src/a2a_handler/tui/components/card.py +++ b/src/a2a_handler/tui/components/card.py @@ -1,4 +1,4 @@ -"""Agent card panel component for displaying agent metadata.""" +"""Agent card panel component for displaying agent metadata and capabilities.""" import json import re diff --git a/src/a2a_handler/tui/components/contact.py b/src/a2a_handler/tui/components/contact.py index da8612c..2f6acdc 100644 --- a/src/a2a_handler/tui/components/contact.py +++ b/src/a2a_handler/tui/components/contact.py @@ -1,4 +1,4 @@ -"""Contact panel component for agent connection management.""" +"""Contact panel component for managing agent connections.""" from textual.app import ComposeResult from textual.containers import Container, Horizontal diff --git a/src/a2a_handler/tui/components/footer.py b/src/a2a_handler/tui/components/footer.py index b7cd5d4..070387c 100644 --- a/src/a2a_handler/tui/components/footer.py +++ b/src/a2a_handler/tui/components/footer.py @@ -1,4 +1,4 @@ -"""Footer component with keyboard shortcut buttons.""" +"""Footer component displaying keyboard shortcut buttons.""" from textual.app import ComposeResult from textual.containers import Container, Horizontal diff --git a/src/a2a_handler/tui/components/input.py b/src/a2a_handler/tui/components/input.py index 0884926..74c5f4d 100644 --- a/src/a2a_handler/tui/components/input.py +++ b/src/a2a_handler/tui/components/input.py @@ -1,4 +1,4 @@ -"""Input panel component for message composition.""" +"""Input panel component for composing and sending messages.""" from textual.app import ComposeResult from textual.containers import Container, Horizontal diff --git a/src/a2a_handler/tui/components/messages.py b/src/a2a_handler/tui/components/messages.py index 1a3527d..0822821 100644 --- a/src/a2a_handler/tui/components/messages.py +++ b/src/a2a_handler/tui/components/messages.py @@ -1,4 +1,4 @@ -"""Messages panel component for chat display.""" +"""Messages panel component for displaying chat history.""" from datetime import datetime from typing import Any diff --git a/src/a2a_handler/validation.py b/src/a2a_handler/validation.py index bc36bdc..7cfa0bb 100644 --- a/src/a2a_handler/validation.py +++ b/src/a2a_handler/validation.py @@ -1,4 +1,7 @@ -"""A2A protocol validation utilities.""" +"""Agent card validation utilities for the A2A protocol. + +Validates agent cards from URLs or local files using the A2A SDK. +""" import json from dataclasses import dataclass, field diff --git a/src/a2a_handler/webhook.py b/src/a2a_handler/webhook.py index 268aabe..cbed26b 100644 --- a/src/a2a_handler/webhook.py +++ b/src/a2a_handler/webhook.py @@ -1,7 +1,6 @@ -"""Local webhook server for receiving A2A push notifications. +"""Webhook server for receiving A2A push notifications. -This module provides a simple HTTP server that can receive push notifications -from A2A agents for testing purposes. +Provides an HTTP server for receiving and displaying push notifications from A2A agents. """ import json diff --git a/tests/__init__.py b/tests/__init__.py index 8b13789..abd12bc 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1 @@ - +"""Test suite for the Handler package.""" diff --git a/tests/test_formatting.py b/tests/test_formatting.py index b3825c0..9b14ecf 100644 --- a/tests/test_formatting.py +++ b/tests/test_formatting.py @@ -1,4 +1,4 @@ -"""Tests for formatting utilities.""" +"""Tests for the formatting utilities module.""" from a2a_handler.common.formatting import format_field_name, format_value diff --git a/tests/test_service.py b/tests/test_service.py index 78f02bb..c0a0f18 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -1,4 +1,4 @@ -"""Tests for A2A service layer.""" +"""Tests for the A2A service layer module.""" from a2a.types import Part, Task, TaskState, TaskStatus, TextPart diff --git a/tests/test_session.py b/tests/test_session.py index 80463d5..7319f66 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,4 +1,4 @@ -"""Tests for session state management.""" +"""Tests for the session state management module.""" import tempfile from pathlib import Path diff --git a/tests/test_tui.py b/tests/test_tui.py index cd667e3..4177d6f 100644 --- a/tests/test_tui.py +++ b/tests/test_tui.py @@ -1,3 +1,5 @@ +"""Tests for the TUI application.""" + import pytest from a2a_handler.tui import HandlerTUI diff --git a/tests/test_validation.py b/tests/test_validation.py index 275fa53..6c06284 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,4 +1,4 @@ -"""Tests for A2A protocol validation.""" +"""Tests for the agent card validation module.""" import json import tempfile diff --git a/tests/test_webhook.py b/tests/test_webhook.py index 5f0f655..d346a80 100644 --- a/tests/test_webhook.py +++ b/tests/test_webhook.py @@ -1,4 +1,4 @@ -"""Tests for webhook server.""" +"""Tests for the webhook server module.""" from datetime import datetime From a94f7cd0942af8610f14181eb82976a8c436a52b Mon Sep 17 00:00:00 2001 From: Al Duncanson Date: Tue, 9 Dec 2025 21:45:07 -0500 Subject: [PATCH 23/23] refactor: Handler agent description and prompt --- src/a2a_handler/server.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/src/a2a_handler/server.py b/src/a2a_handler/server.py index 563a8be..ea66162 100644 --- a/src/a2a_handler/server.py +++ b/src/a2a_handler/server.py @@ -66,18 +66,25 @@ def create_llm_agent() -> Agent: agent = Agent( name="Handler", model=language_model, - description="Handler assistant", - instruction="""You are Handler, the resident helpful agent for the Handler application. -You are an expert on the Handler toolkit, which is a terminal-based system for communicating with and testing Agent-to-Agent (A2A) protocol agents. -You know that the Handler project consists of: -1. A TUI (Text User Interface) for interactive agent management -2. A CLI (Command Line Interface) for scripting and quick interactions -3. A Client library (packages/client) that implements the A2A protocol -4. A server agent (packages/server) - which is what you are currently running on! - -You should be helpful, friendly, and eager to explain how Handler works. -If asked about installation, usage, or development, provide clear, concise guidance based on the project structure. -You are proud to be an A2A server agent.""", + description="Handler's Agent", + instruction="""You are Handler's Agent, the built-in assistant for the Handler application. + +Handler is an A2A (Agent-to-Agent) protocol client published on PyPI as `a2a-handler`. It provides tools for developers to communicate with, test, and debug A2A-compatible agents. + +Handler's architecture consists of: +1. **TUI** - An interactive terminal interface (Textual-based) for managing agent connections, sending messages, and viewing streaming responses +2. **CLI** - A rich-click powered command-line interface for scripting and automation with commands for: + - `message send/stream` - Send messages to agents with optional streaming + - `task get/cancel/resubscribe` - Manage A2A tasks + - `card get/validate` - Retrieve and validate agent cards + - `session list/show/clear` - Manage conversation sessions + - `server agent/push` - Run local servers (including this one!) +3. **A2AService** - A unified service layer wrapping the a2a-sdk for protocol operations +4. **Server Agent** - A local A2A-compatible agent (you!) for testing, built with Google ADK and LiteLLM/Ollama + +Handler supports streaming responses, push notifications, session persistence, and both JSON and formatted text output. + +You are running as Handler's built-in server agent, useful for testing A2A integrations locally. Be helpful, concise, and knowledgeable about both Handler and the A2A protocol.""", ) logger.info(