diff --git a/AGENTS.md b/AGENTS.md index 1ce6b2d..3e772fc 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,70 +1,32 @@ # Agent Development Guide -## Commands -Use `just` for all development tasks: +## Quick Start -| Command | Description | -|---------|-------------| -| `just install` | Install dependencies | -| `just check` | Run lint, format, and typecheck | -| `just fix` | Auto-fix lint/format issues | -| `just test` | Run pytest test suite | -| `just server` | Start A2A server (port 8000) | -| `just tui` | Run TUI application | -| `just tui-dev` | Run TUI with Textual devtools | -| `just web` | Serve TUI as web app | -| `just console` | Run Textual devtools console | -| `just get-card` | Fetch agent card (CLI) | -| `just send` | Send message to agent (CLI) | -| `just validate` | Validate agent card (CLI) | -| `just version` | Show current version | -| `just bump` | Bump version (patch, minor, major) | -| `just tag` | Create git tag for current version | -| `just release` | Tag and push release to origin | +```bash +just install # Install dependencies +just check # Run lint, format, and typecheck +just test # Run tests +``` + +Run `just` to see all available commands. ## Project Structure ``` -handler/ -├── src/a2a_handler/ # Main package -│ ├── _version.py # Version string -│ ├── cli.py # CLI (rich-click) -│ ├── client.py # A2A protocol client (a2a-sdk) -│ ├── validation.py # Agent card validation utilities -│ ├── server.py # A2A server agent (google-adk, litellm) -│ ├── tui.py # TUI application (textual) -│ ├── common/ # Shared utilities (rich, logging) -│ │ ├── logging.py -│ │ └── printing.py -│ └── components/ # TUI components -└── tests/ # pytest tests +src/a2a_handler/ +├── cli.py # CLI entry point +├── client.py # A2A protocol client +├── validation.py # Agent card validation +├── server.py # A2A server agent +├── tui.py # TUI application +├── common/ # Shared utilities +└── components/ # TUI components ``` -## Code Style & Conventions +## Code Style - **Python 3.11+** with full type hints -- **Formatting**: `ruff format` (black compatible) +- **Formatting**: `ruff format` - **Linting**: `ruff check` - **Type Checking**: `ty check` -- **Imports**: Standard → Third-party → Local -- **Testing**: pytest with pytest-asyncio for async tests - -## Environment Variables - -- `OLLAMA_API_BASE`: Ollama server URL (default: `http://localhost:11434`) -- `OLLAMA_MODEL`: Model to use (default: `qwen3`) - -## A2A Protocol - -The `a2a_handler.client` module provides A2A protocol logic: -- `build_http_client()` - Create configured HTTP client -- `fetch_agent_card()` - Retrieve agent metadata -- `send_message_to_agent()` - Send messages and get responses - -## Key Dependencies - -- **CLI**: `rich-click` (enhanced `click` with rich formatting) -- **Client**: `a2a-sdk`, `httpx` -- **Server**: `google-adk`, `litellm`, `uvicorn` -- **TUI**: `textual` -- **Common**: `rich` +- **Testing**: pytest with pytest-asyncio diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d9c1d34..7555103 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,78 +1,27 @@ # Contributing to Handler -## Architecture - -Handler is a single Python package (`a2a-handler`) with all modules under `src/a2a_handler/`: - -| Module | Description | -|--------|-------------| -| `cli.py` | CLI built with `rich-click`. Entry point: `handler` | -| `client.py` | A2A protocol client library using `a2a-sdk` | -| `validation.py` | Agent card validation utilities | -| `common/` | Shared utilities (logging, printing with `rich`) | -| `server.py` | Reference A2A agent using `google-adk` + `litellm` | -| `tui.py` | TUI application built with `textual` | -| `components/` | TUI components | - ## Prerequisites - **Python 3.11+** - **[uv](https://github.com/astral-sh/uv)** for dependency management -- **[just](https://github.com/casey/just)** for running commands (recommended) -- **[Ollama](https://ollama.com/)** for running the reference server agent +- **[just](https://github.com/casey/just)** for running commands +- **[Ollama](https://ollama.com/)** for running the server agent ## Setup ```bash git clone https://github.com/alDuncanson/handler.git cd handler -just install # or: uv sync +just install ``` -## Development Commands +## Development -| Command | Description | -|---------|-------------| -| `just install` | Install dependencies | -| `just check` | Run lint, format, and typecheck | -| `just fix` | Auto-fix lint/format issues | -| `just test` | Run pytest test suite | -| `just server` | Start A2A server on port 8000 | -| `just tui` | Run TUI application | -| `just tui-dev` | Run TUI with Textual devtools | -| `just web` | Serve TUI as web app | -| `just console` | Run Textual devtools console | -| `just get-card [url]` | Fetch agent card from URL | -| `just send [url] [msg]` | Send message to agent | -| `just validate [source]` | Validate agent card from URL or file | -| `just version` | Show current version | -| `just bump [level]` | Bump version (patch, minor, major) | -| `just tag` | Create git tag for current version | -| `just release` | Tag and push release to origin | +Run `just` to see all available commands. ## Code Style -- **Formatting**: `ruff format` (black compatible) +- **Formatting**: `ruff format` - **Linting**: `ruff check` - **Type Checking**: `ty check` -- **Imports**: Standard → Third-party → Local -- **Testing**: Add `pytest` tests for new functionality - -## Environment Variables - -| Variable | Default | Description | -|----------|---------|-------------| -| `OLLAMA_API_BASE` | `http://localhost:11434` | Ollama server URL | -| `OLLAMA_MODEL` | `qwen3` | Model for reference agent | - -## A2A Protocol - -The `a2a_handler.client` module provides the A2A protocol implementation: - -```python -from a2a_handler.client import build_http_client, fetch_agent_card, send_message_to_agent - -async with build_http_client() as client: - card = await fetch_agent_card("http://localhost:8000", client) - response = await send_message_to_agent("http://localhost:8000", "Hello", client) -``` +- **Testing**: pytest diff --git a/README.md b/README.md index 8530728..e89bec9 100644 --- a/README.md +++ b/README.md @@ -27,74 +27,16 @@ uvx --from a2a-handler handler ## Use -Then, you can use Handler: - ```bash -handler +handler --help ``` -If you don't have an A2A server to connect to, Handler provides a local A2A server agent: +To start a local A2A server agent (requires [Ollama](https://ollama.com/)): ```bash handler server ``` -> The server agent requires [Ollama](https://ollama.com/) to be running locally. By default it connects to `http://localhost:11434` and uses the `qwen3` model. -> -> 1. Install and run Ollama -> 2. Pull the model: `ollama pull qwen3` -> 3. (Optional) Configure via environment variables: `OLLAMA_API_BASE` and `OLLAMA_MODEL` - -### TUI - -Interactive terminal user interface: - -```bash -handler tui -``` - -### CLI - -#### Global Options - -```bash -handler --verbose # Enable verbose logging output -handler --debug # Enable debug logging output -handler --help # Show help for any command -``` - -#### Commands - -Fetch agent card from A2A server: - -```bash -handler card http://localhost:8000 -handler card http://localhost:8000 --output json # JSON output -``` - -Validate an agent card from a URL or file: - -```bash -handler validate http://localhost:8000 # Validate from URL -handler validate ./agent-card.json # Validate from file -handler validate http://localhost:8000 --output json # JSON output -``` - -Send a message to an A2A agent: - -```bash -handler send http://localhost:8000 "Hello World" -handler send http://localhost:8000 "Hello" --output json # JSON output -handler send http://localhost:8000 "Hello" --context-id abc # Conversation continuity -handler send http://localhost:8000 "Hello" --task-id xyz # Reference existing task -``` - -Show version: - -```bash -handler version -``` - ## Contributing -See [CONTRIBUTING.md](CONTRIBUTING.md) for architecture and development instructions. +See [CONTRIBUTING.md](CONTRIBUTING.md). diff --git a/pyproject.toml b/pyproject.toml index a3305ae..95bf14a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,3 +69,6 @@ build-backend = "uv_build" [tool.uv.build-backend] module-root = "src" + +[tool.ruff.lint.per-file-ignores] +"src/a2a_handler/cli.py" = ["E402"] diff --git a/src/a2a_handler/__init__.py b/src/a2a_handler/__init__.py index 5c492a3..ce0ee7c 100644 --- a/src/a2a_handler/__init__.py +++ b/src/a2a_handler/__init__.py @@ -1,4 +1,7 @@ -"""Handler - A2A protocol client and TUI for agent interaction""" +"""Handler - A2A protocol client and TUI for agent interaction. + +Provides CLI and TUI interfaces for communicating with A2A protocol agents. +""" from importlib.metadata import version diff --git a/src/a2a_handler/cli.py b/src/a2a_handler/cli.py index 1b3facc..b1cda55 100644 --- a/src/a2a_handler/cli.py +++ b/src/a2a_handler/cli.py @@ -1,22 +1,55 @@ +"""Command-line interface for the Handler A2A protocol client. + +Provides commands for interacting with A2A agents: +- message send/stream: Send messages to agents +- task get/cancel/resubscribe: Manage tasks +- task notification set: Configure push notifications +- card get/validate: Agent card operations +- server agent/push: Run local servers +- session list/show/clear: Manage saved sessions +""" + import asyncio -import json import logging from typing import Any, Optional +logging.getLogger().setLevel(logging.WARNING) + import httpx import rich_click as click +from a2a.client.errors import ( + A2AClientError, + A2AClientHTTPError, + A2AClientTimeoutError, +) +from a2a.types import AgentCard from a2a_handler import __version__ from a2a_handler.common import ( - console, + format_field_name, + format_value, get_logger, - print_error, - print_json, - print_markdown, - print_panel, + get_output_context, setup_logging, ) +from a2a_handler.common.output import Output +from a2a_handler.server import run_server +from a2a_handler.service import A2AService, SendResult, TaskResult +from a2a_handler.session import ( + clear_session, + get_session, + get_session_store, + update_session, +) +from a2a_handler.tui import HandlerTUI +from a2a_handler.validation import ( + ValidationResult, + validate_agent_card_from_file, + validate_agent_card_from_url, +) +from a2a_handler.webhook import run_webhook_server +# rich_click configuration click.rich_click.USE_RICH_MARKUP = True click.rich_click.USE_MARKDOWN = True click.rich_click.SHOW_ARGUMENTS = True @@ -26,167 +59,186 @@ click.rich_click.STYLE_ARGUMENT = "cyan" click.rich_click.STYLE_COMMAND = "green" click.rich_click.STYLE_SWITCH = "bold green" + click.rich_click.OPTION_GROUPS = { "handler": [ + {"name": "Global Options", "options": ["--verbose", "--debug", "--help"]}, + {"name": "Output Options", "options": ["--raw"]}, + ], + "handler message send": [ { - "name": "Global Options", - "options": ["--verbose", "--debug", "--help"], + "name": "Message Options", + "options": ["--stream", "--continue", "--context-id", "--task-id"], }, + { + "name": "Push Notification Options", + "options": ["--push-url", "--push-token"], + }, + {"name": "Output Options", "options": ["--output", "--help"]}, ], - "handler send": [ + "handler message stream": [ { "name": "Conversation Options", - "options": ["--context-id", "--task-id"], + "options": ["--continue", "--context-id", "--task-id"], }, { - "name": "Output Options", - "options": ["--output", "--help"], + "name": "Push Notification Options", + "options": ["--push-url", "--push-token"], }, + {"name": "Output Options", "options": ["--output", "--help"]}, ], - "handler server": [ - { - "name": "Server Options", - "options": ["--host", "--port", "--help"], - }, + "handler task get": [ + {"name": "Query Options", "options": ["--history-length"]}, + {"name": "Output Options", "options": ["--output", "--help"]}, + ], + "handler task notification set": [ + {"name": "Notification Options", "options": ["--url", "--token"]}, + {"name": "Output Options", "options": ["--output", "--help"]}, + ], + "handler card get": [ + {"name": "Card Options", "options": ["--authenticated"]}, + {"name": "Output Options", "options": ["--output", "--help"]}, + ], + "handler server agent": [ + {"name": "Server Options", "options": ["--host", "--port", "--help"]}, + ], + "handler server push": [ + {"name": "Server Options", "options": ["--host", "--port", "--help"]}, + ], + "handler session clear": [ + {"name": "Clear Options", "options": ["--all", "--help"]}, ], } + click.rich_click.COMMAND_GROUPS = { "handler": [ - { - "name": "Agent Commands", - "commands": ["card", "send", "validate"], - }, - { - "name": "Interface Commands", - "commands": ["tui", "server"], - }, - { - "name": "Utility Commands", - "commands": ["version"], - }, + {"name": "Agent Communication", "commands": ["message", "task"]}, + {"name": "Agent Discovery", "commands": ["card"]}, + {"name": "Interfaces", "commands": ["tui", "server"]}, + {"name": "Utilities", "commands": ["session", "version"]}, + ], + "handler message": [ + {"name": "Message Commands", "commands": ["send", "stream"]}, + ], + "handler task": [ + {"name": "Task Commands", "commands": ["get", "cancel", "resubscribe"]}, + {"name": "Push Notifications", "commands": ["notification"]}, + ], + "handler task notification": [ + {"name": "Notification Commands", "commands": ["set"]}, + ], + "handler card": [ + {"name": "Card Commands", "commands": ["get", "validate"]}, + ], + "handler server": [ + {"name": "Server Commands", "commands": ["agent", "push"]}, + ], + "handler session": [ + {"name": "Session Commands", "commands": ["list", "show", "clear"]}, ], } -setup_logging(level="WARNING") -from a2a.client.errors import ( # noqa: E402 - A2AClientError, - A2AClientHTTPError, - A2AClientTimeoutError, -) +TIMEOUT = 120 +log = get_logger(__name__) -from a2a_handler.client import ( # noqa: E402 - build_http_client, - fetch_agent_card, - parse_response, - send_message_to_agent, -) -from a2a_handler.server import run_server # noqa: E402 -from a2a_handler.tui import HandlerTUI # noqa: E402 -from a2a_handler.validation import ( # noqa: E402 - ValidationResult, - validate_agent_card_from_file, - validate_agent_card_from_url, -) -log = get_logger(__name__) +def build_http_client(timeout: int = TIMEOUT) -> httpx.AsyncClient: + """Build an HTTP client with the specified timeout.""" + return httpx.AsyncClient(timeout=timeout) + + +def _handle_client_error(e: Exception, agent_url: str, context: object) -> None: + """Handle A2A client errors with appropriate messages.""" + output = context if isinstance(context, Output) else None + + message = "" + if isinstance(e, A2AClientTimeoutError): + log.error("Request to %s timed out", agent_url) + message = "Request timed out" + elif isinstance(e, A2AClientHTTPError): + log.error("A2A client error: %s", e) + message = ( + f"Connection failed: Is the server running at {agent_url}?" + if "connection" in str(e).lower() + else str(e) + ) + elif isinstance(e, A2AClientError): + log.error("A2A client error: %s", e) + message = str(e) + elif isinstance(e, httpx.ConnectError): + log.error("Connection refused to %s", agent_url) + message = f"Connection refused: Is the server running at {agent_url}?" + elif isinstance(e, httpx.TimeoutException): + log.error("Request to %s timed out", agent_url) + message = "Request timed out" + elif isinstance(e, httpx.HTTPStatusError): + log.error("HTTP error %d from %s", e.response.status_code, agent_url) + message = f"HTTP {e.response.status_code} - {e.response.text}" + else: + log.exception("Failed request to %s", agent_url) + message = str(e) + + if output: + output.error(message) + else: + click.echo(f"Error: {message}", err=True) + + +# ============================================================================ +# Main CLI Group +# ============================================================================ @click.group() -@click.option("--verbose", "-v", is_flag=True, help="Enable verbose logging output") -@click.option("--debug", "-d", is_flag=True, help="Enable debug logging output") +@click.option("--verbose", "-v", is_flag=True, help="Enable verbose logging") +@click.option("--debug", "-d", is_flag=True, help="Enable debug logging") +@click.option("--raw", "-r", is_flag=True, help="Output raw text without formatting") @click.pass_context -def cli(ctx, verbose: bool, debug: bool) -> None: - """Handler A2A protocol client CLI.""" +def cli(ctx: click.Context, verbose: bool, debug: bool, raw: bool) -> None: + """Handler - A2A protocol client CLI.""" ctx.ensure_object(dict) + ctx.obj["raw"] = raw + if debug: - log.debug("Debug logging enabled") setup_logging(level="DEBUG") elif verbose: - log.debug("Verbose logging enabled") setup_logging(level="INFO") else: - setup_logging(level="WARNING") - - -def _format_field_name(name: str) -> str: - """Convert snake_case or camelCase to Title Case.""" - import re - - name = re.sub(r"([a-z])([A-Z])", r"\1 \2", name) - name = name.replace("_", " ") - return name.title() - - -def _format_value(value: Any, indent: int = 0) -> str: - """Recursively format a value for display, returning only truthy content.""" - prefix = " " * indent - - if value is None or value == "" or value == [] or value == {}: - return "" - - if isinstance(value, bool): - return "✓" if value else "✗" - - if isinstance(value, str): - return value - - if isinstance(value, int | float): - return str(value) - - if isinstance(value, list): - lines: list[str] = [] - for item in value: - if hasattr(item, "model_dump"): - item_dict: dict[str, Any] = item.model_dump() - name = item_dict.get("name") or item_dict.get("id") or "Item" - desc = item_dict.get("description") or "" - if desc: - desc_prefix = " " * (indent + 1) - lines.append(f"{prefix} • [cyan]{name}[/cyan]") - lines.append(f"{desc_prefix} {desc}") - else: - lines.append(f"{prefix} • [cyan]{name}[/cyan]") - elif isinstance(item, dict): - item_d: dict[str, Any] = item - name = item_d.get("name") or item_d.get("id") or "Item" - desc = item_d.get("description") or "" - if desc: - desc_prefix = " " * (indent + 1) - lines.append(f"{prefix} • [cyan]{name}[/cyan]") - lines.append(f"{desc_prefix} {desc}") - else: - lines.append(f"{prefix} • [cyan]{name}[/cyan]") - else: - formatted = _format_value(item, indent) - if formatted: - lines.append(f"{prefix} • {formatted}") - return "\n" + "\n".join(lines) if lines else "" - - if hasattr(value, "model_dump"): - value = value.model_dump() - - if isinstance(value, dict): - dict_lines: list[str] = [] - for k, v in value.items(): - if isinstance(k, str) and k.startswith("_"): - continue - formatted = _format_value(v, indent + 1) - if formatted: - field_name = _format_field_name(str(k)) - if "\n" in formatted: - dict_lines.append( - f"{prefix}[bold]{field_name}:[/bold]\n{formatted}" - ) - else: - dict_lines.append(f"{prefix}[bold]{field_name}:[/bold] {formatted}") - return "\n".join(dict_lines) if dict_lines else "" + setup_logging(level="ERROR") - return str(value) if value else "" +def get_mode(ctx: click.Context, output: str) -> str: + """Get output mode from context and output option.""" + if output == "json": + return "json" + if ctx.obj.get("raw"): + return "raw" + return "text" -@cli.command() + +# ============================================================================ +# Message Commands +# ============================================================================ + + +@cli.group() +def message() -> None: + """Send messages to A2A agents.""" + pass + + +@message.command("send") @click.argument("agent_url") +@click.argument("text") +@click.option("--stream", "-s", is_flag=True, help="Stream responses in real-time") +@click.option("--context-id", help="Context ID for conversation continuity") +@click.option("--task-id", help="Task ID to continue") +@click.option( + "--continue", "-C", "use_session", is_flag=True, help="Continue from saved session" +) +@click.option("--push-url", help="Webhook URL for push notifications") +@click.option("--push-token", help="Authentication token for push notifications") @click.option( "--output", "-o", @@ -194,149 +246,190 @@ def _format_value(value: Any, indent: int = 0) -> str: default="text", help="Output format", ) -def card(agent_url: str, output: str) -> None: - """Fetch and display an agent card from AGENT_URL.""" - log.info("Fetching agent card from %s", agent_url) +@click.pass_context +def message_send( + ctx: click.Context, + agent_url: str, + text: str, + stream: bool, + context_id: Optional[str], + task_id: Optional[str], + use_session: bool, + push_url: Optional[str], + push_token: Optional[str], + output: str, +) -> None: + """Send a message to an agent and receive a response.""" + log.info("Sending message to %s", agent_url) - async def fetch() -> None: - try: - log.debug("Building HTTP client") - async with build_http_client() as client: - log.debug("Requesting agent card") - card_data = await fetch_agent_card(agent_url, client) - log.info("Retrieved card for agent: %s", card_data.name) - - if output == "json": - log.debug("Outputting card as JSON") - print_json(card_data.model_dump_json(indent=2)) - else: - log.debug("Outputting card as formatted text") - card_dict = card_data.model_dump() - - name = card_dict.pop("name", "Unknown Agent") - description = card_dict.pop("description", "") - - title = f"[bold green]{name}[/bold green] [dim]v{__version__}[/dim]" - content_parts = [] - - if description: - content_parts.append(f"[italic]{description}[/italic]") - - for key, value in card_dict.items(): - if key.startswith("_"): - continue - formatted = _format_value(value) - if formatted: - field_name = _format_field_name(key) - if "\n" in formatted: - content_parts.append( - f"[bold]{field_name}:[/bold]\n{formatted}" - ) - else: - content_parts.append( - f"[bold]{field_name}:[/bold] {formatted}" - ) + if use_session and not context_id: + session = get_session(agent_url) + if session.context_id: + context_id = session.context_id + log.info("Using saved context: %s", context_id) + + mode = get_mode(ctx, output) + + async def do_send() -> None: + with get_output_context(mode) as output: + try: + async with build_http_client() as http_client: + service = A2AService( + http_client, + agent_url, + enable_streaming=stream, + push_notification_url=push_url, + push_notification_token=push_token, + ) - print_panel("\n\n".join(content_parts), title=title) + if mode != "json": + output.dim(f"Sending to {agent_url}...") - except A2AClientTimeoutError: - log.error("Request to %s timed out", agent_url) - print_error("Request timed out") - raise click.Abort() - except A2AClientHTTPError as e: - log.error("A2A client error: %s", e) - if "connection" in str(e).lower(): - print_error(f"Connection failed: Is the server running at {agent_url}?") - else: - print_error(str(e)) - raise click.Abort() - except A2AClientError as e: - log.error("A2A client error: %s", e) - print_error(str(e)) - raise click.Abort() - except httpx.ConnectError: - log.error("Connection refused to %s", agent_url) - print_error(f"Connection refused: Is the server running at {agent_url}?") - raise click.Abort() - except httpx.TimeoutException: - log.error("Request to %s timed out", agent_url) - print_error("Request timed out") - raise click.Abort() - except httpx.HTTPStatusError as e: - log.error( - "HTTP error %d from %s: %s", - e.response.status_code, - agent_url, - e.response.text, - ) - print_error(f"HTTP {e.response.status_code} - {e.response.text}") - raise click.Abort() - except Exception as e: - log.exception("Failed to fetch agent card from %s", agent_url) - print_error(str(e)) - raise click.Abort() - - asyncio.run(fetch()) - - -def _format_validation_result(result: ValidationResult, output: str) -> None: - """Format and print validation result.""" - if output == "json": - import json - - output_data = { - "valid": result.valid, - "source": result.source, - "sourceType": result.source_type.value, - "agentName": result.agent_name, - "protocolVersion": result.protocol_version, - "issues": [ - {"field": i.field, "message": i.message, "type": i.issue_type} - for i in result.issues - ], - "warnings": [ - {"field": w.field, "message": w.message, "type": w.issue_type} - for w in result.warnings - ], - } - print_json(json.dumps(output_data, indent=2)) + if stream: + await _stream_message( + service, text, context_id, task_id, agent_url, output + ) + else: + result = await service.send(text, context_id, task_id) + update_session(agent_url, result.context_id, result.task_id) + _format_send_result(result, output) + + except Exception as e: + _handle_client_error(e, agent_url, output) + raise click.Abort() + + asyncio.run(do_send()) + + +@message.command("stream") +@click.argument("agent_url") +@click.argument("text") +@click.option("--context-id", help="Context ID for conversation continuity") +@click.option("--task-id", help="Task ID to continue") +@click.option( + "--continue", "-C", "use_session", is_flag=True, help="Continue from saved session" +) +@click.option("--push-url", help="Webhook URL for push notifications") +@click.option("--push-token", help="Authentication token for push notifications") +@click.option( + "--output", + "-o", + type=click.Choice(["json", "text"]), + default="text", + help="Output format", +) +@click.pass_context +def message_stream( + ctx: click.Context, + agent_url: str, + text: str, + context_id: Optional[str], + task_id: Optional[str], + use_session: bool, + push_url: Optional[str], + push_token: Optional[str], + output: str, +) -> None: + """Send a message and stream the response in real-time.""" + ctx.invoke( + message_send, + agent_url=agent_url, + text=text, + stream=True, + context_id=context_id, + task_id=task_id, + use_session=use_session, + push_url=push_url, + push_token=push_token, + output=output, + ) + + +async def _stream_message( + service: A2AService, + text: str, + context_id: Optional[str], + task_id: Optional[str], + agent_url: str, + output: Output, +) -> None: + """Stream a message and handle events.""" + collected_text: list[str] = [] + last_context_id: str | None = None + last_task_id: str | None = None + last_state = None + + is_json = output.mode.value == "json" + + async for event in service.stream(text, context_id, task_id): + last_context_id = event.context_id or last_context_id + last_task_id = event.task_id or last_task_id + last_state = event.state or last_state + + if is_json: + event_data = { + "type": event.event_type, + "context_id": event.context_id, + "task_id": event.task_id, + "state": event.state.value if event.state else None, + "text": event.text, + } + output.json(event_data) + else: + if event.text and event.text not in collected_text: + output.line(event.text) + collected_text.append(event.text) + + update_session(agent_url, last_context_id, last_task_id) + + if not is_json: + output.blank() + if last_context_id: + output.field("Context ID", last_context_id, dim_value=True) + if last_task_id: + output.field("Task ID", last_task_id, dim_value=True) + if last_state: + output.state("State", last_state.value) + + +def _format_send_result(result: SendResult, output: Output) -> None: + """Format and display a send result.""" + if output.mode.value == "json": + output.json(result.raw) return - if result.valid: - title = "[bold green]✓ Valid Agent Card[/bold green]" - content_parts = [ - f"[bold]Agent:[/bold] {result.agent_name}", - f"[bold]Protocol Version:[/bold] {result.protocol_version}", - f"[bold]Source:[/bold] {result.source}", - ] - - if result.warnings: - content_parts.append("") - content_parts.append( - f"[bold yellow]Warnings ({len(result.warnings)}):[/bold yellow]" - ) - for warning in result.warnings: - content_parts.append( - f" [yellow]⚠[/yellow] {warning.field}: {warning.message}" - ) - - print_panel("\n".join(content_parts), title=title) + output.blank() + if result.context_id: + output.field("Context ID", result.context_id, dim_value=True) + if result.task_id: + output.field("Task ID", result.task_id, dim_value=True) + if result.state: + output.state("State", result.state.value) + + output.blank() + if result.text: + output.markdown(result.text) else: - title = "[bold red]✗ Invalid Agent Card[/bold red]" - content_parts = [ - f"[bold]Source:[/bold] {result.source}", - "", - f"[bold red]Errors ({len(result.issues)}):[/bold red]", - ] + output.dim("No text content in response") - for issue in result.issues: - content_parts.append(f" [red]✗[/red] {issue.field}: {issue.message}") - print_panel("\n".join(content_parts), title=title) +# ============================================================================ +# Task Commands +# ============================================================================ -@cli.command() -@click.argument("source") +@cli.group() +def task() -> None: + """Manage A2A tasks.""" + pass + + +@task.command("get") +@click.argument("agent_url") +@click.argument("task_id") +@click.option( + "--history-length", "-n", type=int, help="Number of history messages to include" +) @click.option( "--output", "-o", @@ -344,41 +437,73 @@ def _format_validation_result(result: ValidationResult, output: str) -> None: default="text", help="Output format", ) -def validate(source: str, output: str) -> None: - """Validate an agent card from a URL or file path. +@click.pass_context +def task_get( + ctx: click.Context, + agent_url: str, + task_id: str, + history_length: Optional[int], + output: str, +) -> None: + """Retrieve the current status of a task.""" + log.info("Getting task %s from %s", task_id, agent_url) + mode = get_mode(ctx, output) - SOURCE can be either: - - A URL (e.g., http://localhost:8000) - - A file path (e.g., ./agent-card.json) + async def do_get() -> None: + with get_output_context(mode) as output: + try: + async with build_http_client() as http_client: + service = A2AService(http_client, agent_url) + result = await service.get_task(task_id, history_length) + _format_task_result(result, output) + except Exception as e: + _handle_client_error(e, agent_url, output) + raise click.Abort() - The command will automatically detect whether the source is a URL or file. - """ - log.info("Validating agent card from %s", source) + asyncio.run(do_get()) - is_url = source.startswith(("http://", "https://")) - async def do_validate() -> None: - if is_url: - log.debug("Detected URL source") - async with build_http_client() as client: - result = await validate_agent_card_from_url(source, client) - else: - log.debug("Detected file source") - result = validate_agent_card_from_file(source) +@task.command("cancel") +@click.argument("agent_url") +@click.argument("task_id") +@click.option( + "--output", + "-o", + type=click.Choice(["json", "text"]), + default="text", + help="Output format", +) +@click.pass_context +def task_cancel(ctx: click.Context, agent_url: str, task_id: str, output: str) -> None: + """Request cancellation of a task.""" + log.info("Canceling task %s at %s", task_id, agent_url) + mode = get_mode(ctx, output) - _format_validation_result(result, output) + async def do_cancel() -> None: + with get_output_context(mode) as output: + try: + async with build_http_client() as http_client: + service = A2AService(http_client, agent_url) - if not result.valid: - raise click.Abort() + if mode != "json": + output.dim(f"Canceling task {task_id}...") - asyncio.run(do_validate()) + result = await service.cancel_task(task_id) + _format_task_result(result, output) + if mode != "json": + output.success("Task canceled") -@cli.command() + except Exception as e: + _handle_client_error(e, agent_url, output) + raise click.Abort() + + asyncio.run(do_cancel()) + + +@task.command("resubscribe") @click.argument("agent_url") -@click.argument("message") -@click.option("--context-id", help="Context ID for conversation continuity") -@click.option("--task-id", help="Reference an existing task ID") +@click.argument("task_id") @click.option( "--output", "-o", @@ -386,112 +511,402 @@ async def do_validate() -> None: default="text", help="Output format", ) -def send( +@click.pass_context +def task_resubscribe( + ctx: click.Context, agent_url: str, task_id: str, output: str +) -> None: + """Resubscribe to a task's SSE stream after disconnection.""" + log.info("Resubscribing to task %s at %s", task_id, agent_url) + mode = get_mode(ctx, output) + + async def do_resubscribe() -> None: + with get_output_context(mode) as output: + try: + async with build_http_client() as http_client: + service = A2AService(http_client, agent_url) + is_json = output.mode.value == "json" + + if not is_json: + output.dim(f"Resubscribing to task {task_id}...") + + async for event in service.resubscribe(task_id): + if is_json: + output.json( + { + "type": event.event_type, + "context_id": event.context_id, + "task_id": event.task_id, + "state": event.state.value if event.state else None, + "text": event.text, + } + ) + else: + if event.event_type == "status": + output.state( + "Status", + event.state.value if event.state else "unknown", + ) + elif event.text: + output.line(event.text) + + except Exception as e: + _handle_client_error(e, agent_url, output) + raise click.Abort() + + asyncio.run(do_resubscribe()) + + +def _format_task_result(result: TaskResult, output: Output) -> None: + """Format and display a task result.""" + if output.mode.value == "json": + output.json(result.raw) + return + + output.blank() + output.field("Task ID", result.task_id, dim_value=True) + output.state("State", result.state.value) + if result.context_id: + output.field("Context ID", result.context_id, dim_value=True) + + if result.text: + output.blank() + output.markdown(result.text) + + +# ============================================================================ +# Task Notification Commands +# ============================================================================ + + +@task.group("notification") +def task_notification() -> None: + """Manage push notification configurations for tasks.""" + pass + + +@task_notification.command("set") +@click.argument("agent_url") +@click.argument("task_id") +@click.option("--url", "-u", required=True, help="Webhook URL to receive notifications") +@click.option("--token", "-t", help="Authentication token for the webhook") +@click.option( + "--output", + "-o", + type=click.Choice(["json", "text"]), + default="text", + help="Output format", +) +@click.pass_context +def notification_set( + ctx: click.Context, agent_url: str, - message: str, - context_id: Optional[str], - task_id: Optional[str], + task_id: str, + url: str, + token: Optional[str], output: str, ) -> None: - """Send MESSAGE to an agent at AGENT_URL.""" - log.info("Sending message to %s", agent_url) - log.debug("Message: %s", message[:100] if len(message) > 100 else message) - - if context_id: - log.debug("Using context ID: %s", context_id) - if task_id: - log.debug("Using task ID: %s", task_id) - - async def send_msg() -> None: - try: - log.debug("Building HTTP client") - async with build_http_client() as client: - if output == "text": - console.print(f"[dim]Sending message to {agent_url}...[/dim]") - - log.debug("Sending message via A2A client") - response = await send_message_to_agent( - agent_url, message, client, context_id, task_id - ) - log.debug("Received response from agent") - - if output == "json": - log.debug("Outputting response as JSON") - print_json(json.dumps(response, indent=2)) - else: - log.debug("Parsing response for text output") - parsed = parse_response(response) - - if parsed.has_content: - log.debug("Response contains %d characters", len(parsed.text)) - print_markdown(parsed.text, title="Response") + """Configure a push notification webhook for a task.""" + log.info("Setting push config for task %s at %s", task_id, agent_url) + mode = get_mode(ctx, output) + + async def do_set() -> None: + with get_output_context(mode) as output: + try: + async with build_http_client() as http_client: + service = A2AService(http_client, agent_url) + is_json = output.mode.value == "json" + + if not is_json: + output.dim(f"Setting notification config for task {task_id}...") + + config = await service.set_push_config(task_id, url, token) + + if is_json: + output.json(config.model_dump()) else: - log.warning("Response contained no text content") - print_markdown("No text in response", title="Response") - - except A2AClientTimeoutError: - log.error("Request to %s timed out", agent_url) - print_error("Request timed out") - raise click.Abort() - except A2AClientHTTPError as e: - log.error("A2A client error: %s", e) - if "connection" in str(e).lower(): - print_error(f"Connection failed: Is the server running at {agent_url}?") + output.success("Push notification config set") + output.field("Task ID", config.task_id) + if config.push_notification_config: + pnc = config.push_notification_config + output.field("URL", pnc.url) + if pnc.token: + output.field("Token", f"{pnc.token[:20]}...") + if pnc.id: + output.field("Config ID", pnc.id) + + except Exception as e: + _handle_client_error(e, agent_url, output) + raise click.Abort() + + asyncio.run(do_set()) + + +# ============================================================================ +# Card Commands +# ============================================================================ + + +@cli.group() +def card() -> None: + """Agent card operations.""" + pass + + +@card.command("get") +@click.argument("agent_url") +@click.option( + "--authenticated", "-a", is_flag=True, help="Request authenticated extended card" +) +@click.option( + "--output", + "-o", + type=click.Choice(["json", "text"]), + default="text", + help="Output format", +) +@click.pass_context +def card_get( + ctx: click.Context, agent_url: str, authenticated: bool, output: str +) -> None: + """Retrieve an agent's card.""" + log.info("Fetching agent card from %s", agent_url) + mode = get_mode(ctx, output) + + async def do_get() -> None: + with get_output_context(mode) as output: + try: + async with build_http_client() as http_client: + service = A2AService(http_client, agent_url) + card_data = await service.get_card() + log.info("Retrieved card for agent: %s", card_data.name) + + if output.mode.value == "json": + output.json(card_data.model_dump()) + else: + _format_agent_card(card_data, output) + + except Exception as e: + _handle_client_error(e, agent_url, output) + raise click.Abort() + + asyncio.run(do_get()) + + +def _format_agent_card(card_data: object, output: Output) -> None: + """Format and display an agent card.""" + + card_dict: dict[str, Any] + if isinstance(card_data, AgentCard): + card_dict = card_data.model_dump() + else: + card_dict = {} + name = card_dict.pop("name", "Unknown Agent") + description = card_dict.pop("description", "") + + output.header(name) + if description: + output.line(description) + + output.blank() + for key, value in card_dict.items(): + if key.startswith("_"): + continue + formatted = format_value(value) + if formatted: + field_name = format_field_name(key) + if "\n" in formatted: + output.line(f"{field_name}:") + output.line(formatted) else: - print_error(str(e)) - raise click.Abort() - except A2AClientError as e: - log.error("A2A client error: %s", e) - print_error(str(e)) - raise click.Abort() - except httpx.ConnectError: - log.error("Connection refused to %s", agent_url) - print_error(f"Connection refused: Is the server running at {agent_url}?") - raise click.Abort() - except httpx.TimeoutException: - log.error("Request to %s timed out", agent_url) - print_error("Request timed out") - raise click.Abort() - except httpx.HTTPStatusError as e: - log.error( - "HTTP error %d from %s: %s", - e.response.status_code, - agent_url, - e.response.text, - ) - print_error(f"HTTP {e.response.status_code} - {e.response.text}") - raise click.Abort() - except Exception as e: - log.exception("Failed to send message to %s", agent_url) - print_error(str(e)) - raise click.Abort() - - asyncio.run(send_msg()) + output.field(field_name, formatted) -@cli.command() -def tui() -> None: - """Launch the TUI.""" - log.info("Launching TUI") - logging.getLogger().handlers = [] - app = HandlerTUI() - app.run() +@card.command("validate") +@click.argument("source") +@click.option( + "--output", + "-o", + type=click.Choice(["json", "text"]), + default="text", + help="Output format", +) +@click.pass_context +def card_validate(ctx: click.Context, source: str, output: str) -> None: + """Validate an agent card from URL or file.""" + log.info("Validating agent card from %s", source) + is_url = source.startswith(("http://", "https://")) + mode = get_mode(ctx, output) + + async def do_validate() -> None: + with get_output_context(mode) as output: + if is_url: + async with build_http_client() as http_client: + result = await validate_agent_card_from_url(source, http_client) + else: + result = validate_agent_card_from_file(source) + + _format_validation_result(result, output) + + if not result.valid: + raise SystemExit(1) + + asyncio.run(do_validate()) + + +def _format_validation_result(result: ValidationResult, output: Output) -> None: + """Format and display validation result.""" + if output.mode.value == "json": + output.json( + { + "valid": result.valid, + "source": result.source, + "sourceType": result.source_type.value, + "agentName": result.agent_name, + "protocolVersion": result.protocol_version, + "issues": [ + {"field": i.field_name, "message": i.message, "type": i.issue_type} + for i in result.issues + ], + } + ) + return + + if result.valid: + output.success("Valid Agent Card") + output.field("Agent", result.agent_name) + output.field("Protocol Version", result.protocol_version) + output.field("Source", result.source) + else: + output.error("Invalid Agent Card") + output.field("Source", result.source) + output.blank() + output.line(f"Errors ({len(result.issues)}):") + for issue in result.issues: + output.list_item(f"{issue.field_name}: {issue.message}", bullet="✗") + + +# ============================================================================ +# Server Commands +# ============================================================================ + + +@cli.group() +def server() -> None: + """Run local servers.""" + pass + + +@server.command("agent") +@click.option("--host", default="0.0.0.0", help="Host to bind to", show_default=True) +@click.option("--port", default=8000, help="Port to bind to", show_default=True) +def server_agent(host: str, port: int) -> None: + """Start a local A2A agent server.""" + log.info("Starting A2A server on %s:%d", host, port) + run_server(host, port) + + +@server.command("push") +@click.option("--host", default="127.0.0.1", help="Host to bind to", show_default=True) +@click.option("--port", default=9000, help="Port to bind to", show_default=True) +def server_push(host: str, port: int) -> None: + """Start a local webhook server for receiving push notifications.""" + log.info("Starting webhook server on %s:%d", host, port) + run_webhook_server(host, port) + + +# ============================================================================ +# Session Commands +# ============================================================================ + + +@cli.group() +def session() -> None: + """Manage saved session state.""" + pass + + +@session.command("list") +@click.pass_context +def session_list(ctx: click.Context) -> None: + """List all saved sessions.""" + mode = "raw" if ctx.obj.get("raw") else "text" + + with get_output_context(mode) as output: + store = get_session_store() + sessions = store.list_all() + + if not sessions: + output.dim("No saved sessions") + return + + output.header(f"Saved Sessions ({len(sessions)})") + for s in sessions: + output.blank() + output.subheader(s.agent_url) + if s.context_id: + output.field("Context ID", s.context_id, dim_value=True) + if s.task_id: + output.field("Task ID", s.task_id, dim_value=True) + + +@session.command("show") +@click.argument("agent_url") +@click.pass_context +def session_show(ctx: click.Context, agent_url: str) -> None: + """Display session state for an agent.""" + mode = "raw" if ctx.obj.get("raw") else "text" + + with get_output_context(mode) as output: + s = get_session(agent_url) + output.header(f"Session for {agent_url}") + output.field("Context ID", s.context_id or "none", dim_value=not s.context_id) + output.field("Task ID", s.task_id or "none", dim_value=not s.task_id) + + +@session.command("clear") +@click.argument("agent_url", required=False) +@click.option("--all", "-a", "clear_all", is_flag=True, help="Clear all sessions") +@click.pass_context +def session_clear( + ctx: click.Context, agent_url: Optional[str], clear_all: bool +) -> None: + """Clear saved session state.""" + mode = "raw" if ctx.obj.get("raw") else "text" + + with get_output_context(mode) as output: + if clear_all: + clear_session() + output.success("Cleared all sessions") + elif agent_url: + clear_session(agent_url) + output.success(f"Cleared session for {agent_url}") + else: + output.warning("Provide AGENT_URL or use --all to clear sessions") + + +# ============================================================================ +# Utility Commands +# ============================================================================ @cli.command() def version() -> None: """Display the current version.""" - log.debug("Displaying version: %s", __version__) click.echo(__version__) @cli.command() -@click.option("--host", default="0.0.0.0", help="Host to bind to", show_default=True) -@click.option("--port", default=8000, help="Port to bind to", show_default=True) -def server(host: str, port: int) -> None: - """Start the A2A server agent backed by Ollama.""" - log.info("Starting A2A server on %s:%d", host, port) - run_server(host, port) +def tui() -> None: + """Launch the interactive terminal interface.""" + log.info("Launching TUI") + logging.getLogger().handlers = [] + app = HandlerTUI() + app.run() + + +# ============================================================================ +# Entry Point +# ============================================================================ def main() -> None: diff --git a/src/a2a_handler/client.py b/src/a2a_handler/client.py deleted file mode 100644 index 4fa9768..0000000 --- a/src/a2a_handler/client.py +++ /dev/null @@ -1,173 +0,0 @@ -"""A2A protocol client utilities.""" - -import uuid -from dataclasses import dataclass -from typing import Any - -import httpx -from a2a.client import A2ACardResolver, ClientConfig, ClientFactory -from a2a.types import ( - AgentCard, - Message, - Part, - Role, - TextPart, - TransportProtocol, -) - -from a2a_handler.common import get_logger - -log = get_logger(__name__) - -TIMEOUT = 120 - - -def build_http_client(timeout: int = TIMEOUT) -> httpx.AsyncClient: - """Build an HTTP client with the specified timeout. - - Args: - timeout: Request timeout in seconds - - Returns: - Configured HTTP client - """ - return httpx.AsyncClient(timeout=timeout) - - -async def fetch_agent_card(agent_url: str, client: httpx.AsyncClient) -> AgentCard: - """Fetch agent card from the specified URL. - - Args: - agent_url: The base URL of the agent - client: HTTP client to use for the request - - Returns: - The agent's card with metadata and capabilities - - Raises: - httpx.RequestError: If the request fails - """ - log.info("Fetching agent card from [url]%s[/url]", agent_url) - resolver = A2ACardResolver(client, agent_url) - card = await resolver.get_agent_card() - log.info("Received card for [agent]%s[/agent]", card.name) - return card - - -def _build_message( - message_text: str, - context_id: str | None = None, - task_id: str | None = None, -) -> Message: - """Build a message object. - - Args: - message_text: The message content - context_id: Optional context ID for conversation continuity - task_id: Optional task ID to reference - - Returns: - A properly formatted message - """ - return Message( - message_id=str(uuid.uuid4()), - role=Role.user, - parts=[Part(TextPart(text=message_text))], - context_id=context_id, - task_id=task_id, - ) - - -async def send_message_to_agent( - agent_url: str, - message_text: str, - client: httpx.AsyncClient, - context_id: str | None = None, - task_id: str | None = None, -) -> dict[str, Any]: - """Send a message to an agent and return the response. - - Args: - agent_url: The base URL of the agent - message_text: The message to send - client: HTTP client to use - context_id: Optional context ID for conversation continuity - task_id: Optional task ID to reference - - Returns: - Response data as a dictionary - - Raises: - httpx.RequestError: If the request fails - httpx.TimeoutException: If the request times out - """ - log.info("Sending message to [url]%s[/url]: %s", agent_url, message_text[:50]) - card = await fetch_agent_card(agent_url, client) - log.debug("Connected to [agent]%s[/agent]", card.name) - - config = ClientConfig( - httpx_client=client, supported_transports=[TransportProtocol.jsonrpc] - ) - factory = ClientFactory(config) - a2a_client = factory.create(card) - - message = _build_message(message_text, context_id, task_id) - - log.debug("Sending request with ID: %s", message.message_id) - - last_response = None - async for response in a2a_client.send_message(message): - last_response = response - - log.debug("Received response") - - if last_response is None: - return {} - - if isinstance(last_response, tuple): - return last_response[0].model_dump() - - return last_response.model_dump() if hasattr(last_response, "model_dump") else {} - - -@dataclass -class ParsedResponse: - """Parsed A2A response with extracted text content.""" - - text: str - raw: dict[str, Any] - - @property - def has_content(self) -> bool: - """Check if the response has meaningful content.""" - return bool(self.text) - - -def parse_response(response: dict[str, Any]) -> ParsedResponse: - """Parse an A2A response and extract text content. - - Args: - response: Raw response dictionary from send_message_to_agent - - Returns: - ParsedResponse with extracted text and raw data - """ - if not response: - log.debug("Empty response received") - return ParsedResponse(text="", raw=response) - - texts: list[str] = [] - - if "parts" in response: - texts.extend(p.get("text", "") for p in response["parts"]) - log.debug("Extracted %d parts from response", len(response["parts"])) - - for artifact in response.get("artifacts", []): - artifact_parts = artifact.get("parts", []) - texts.extend(p.get("text", "") for p in artifact_parts) - log.debug("Extracted %d parts from artifact", len(artifact_parts)) - - text = "\n".join(t for t in texts if t) - log.debug("Parsed response with %d characters", len(text)) - - return ParsedResponse(text=text, raw=response) diff --git a/src/a2a_handler/common/__init__.py b/src/a2a_handler/common/__init__.py index ab089ac..503ad72 100644 --- a/src/a2a_handler/common/__init__.py +++ b/src/a2a_handler/common/__init__.py @@ -1,5 +1,12 @@ -"""Common utilities for Handler.""" +"""Common utilities for the Handler package. +Provides logging, formatting, and output utilities shared across modules. +""" + +from .formatting import ( + format_field_name, + format_value, +) from .logging import ( HANDLER_THEME, LogLevel, @@ -7,29 +14,21 @@ get_logger, setup_logging, ) -from .printing import ( - BorderStyle, - print_error, - print_info, - print_json, - print_markdown, - print_panel, - print_success, - print_warning, +from .output import ( + Output, + OutputMode, + get_output_context, ) __all__ = [ - "BorderStyle", "HANDLER_THEME", "LogLevel", + "Output", + "OutputMode", "console", + "format_field_name", + "format_value", "get_logger", - "print_error", - "print_info", - "print_json", - "print_markdown", - "print_panel", - "print_success", - "print_warning", + "get_output_context", "setup_logging", ] diff --git a/src/a2a_handler/common/formatting.py b/src/a2a_handler/common/formatting.py new file mode 100644 index 0000000..cfa959c --- /dev/null +++ b/src/a2a_handler/common/formatting.py @@ -0,0 +1,107 @@ +"""Formatting utilities for displaying structured data. + +Provides functions for converting field names and values to human-readable format. +""" + +import re +from typing import Any + + +def format_field_name(name: str) -> str: + """Convert snake_case or camelCase to Title Case. + + Args: + name: The field name to format. + + Returns: + The formatted field name in Title Case. + + Examples: + >>> format_field_name("snake_case") + 'Snake Case' + >>> format_field_name("camelCase") + 'Camel Case' + >>> format_field_name("already Title") + 'Already Title' + """ + name = re.sub(r"([a-z])([A-Z])", r"\1 \2", name) + name = name.replace("_", " ") + return name.title() + + +def format_value(value: Any, indent: int = 0) -> str: + """Recursively format a value for display, returning only truthy content. + + Handles None, empty strings, empty lists/dicts, bools, strings, numbers, + lists (including Pydantic models), dicts, and objects with model_dump(). + + Args: + value: The value to format. + indent: The current indentation level. + + Returns: + A formatted string representation, or empty string for falsy values. + """ + prefix = " " * indent + + if value is None or value == "" or value == [] or value == {}: + return "" + + if isinstance(value, bool): + return "✓" if value else "✗" + + if isinstance(value, str): + return value + + if isinstance(value, int | float): + return str(value) + + if isinstance(value, list): + lines: list[str] = [] + for item in value: + if hasattr(item, "model_dump"): + item_dict: dict[str, Any] = item.model_dump() + name = item_dict.get("name") or item_dict.get("id") or "Item" + desc = item_dict.get("description") or "" + if desc: + desc_prefix = " " * (indent + 1) + lines.append(f"{prefix} • [cyan]{name}[/cyan]") + lines.append(f"{desc_prefix} {desc}") + else: + lines.append(f"{prefix} • [cyan]{name}[/cyan]") + elif isinstance(item, dict): + item_d: dict[str, Any] = item + name = item_d.get("name") or item_d.get("id") or "Item" + desc = item_d.get("description") or "" + if desc: + desc_prefix = " " * (indent + 1) + lines.append(f"{prefix} • [cyan]{name}[/cyan]") + lines.append(f"{desc_prefix} {desc}") + else: + lines.append(f"{prefix} • [cyan]{name}[/cyan]") + else: + formatted = format_value(item, indent) + if formatted: + lines.append(f"{prefix} • {formatted}") + return "\n" + "\n".join(lines) if lines else "" + + if hasattr(value, "model_dump"): + value = value.model_dump() + + if isinstance(value, dict): + dict_lines: list[str] = [] + for k, v in value.items(): + if isinstance(k, str) and k.startswith("_"): + continue + formatted = format_value(v, indent + 1) + if formatted: + field_name = format_field_name(str(k)) + if "\n" in formatted: + dict_lines.append( + f"{prefix}[bold]{field_name}:[/bold]\n{formatted}" + ) + else: + dict_lines.append(f"{prefix}[bold]{field_name}:[/bold] {formatted}") + return "\n".join(dict_lines) if dict_lines else "" + + return str(value) if value else "" diff --git a/src/a2a_handler/common/logging.py b/src/a2a_handler/common/logging.py index 89d8ff4..9f9d6ea 100644 --- a/src/a2a_handler/common/logging.py +++ b/src/a2a_handler/common/logging.py @@ -1,4 +1,7 @@ -"""Unified Rich logging configuration for Handler packages.""" +"""Rich logging configuration for Handler. + +Provides themed console output and structured logging across all modules. +""" import logging from typing import Literal diff --git a/src/a2a_handler/common/output.py b/src/a2a_handler/common/output.py new file mode 100644 index 0000000..4edabfa --- /dev/null +++ b/src/a2a_handler/common/output.py @@ -0,0 +1,180 @@ +"""Output formatting system with mode-aware styling. + +Provides a unified output interface supporting raw, text, and JSON modes. +""" + +from __future__ import annotations + +import json as json_module +import re +from contextlib import contextmanager +from enum import Enum +from typing import Any, Generator + +from rich.console import Console +from rich.markdown import Markdown + +from .logging import console + + +TERMINAL_STATES = {"completed", "failed", "canceled", "rejected"} +SUCCESS_STATES = {"completed"} +ERROR_STATES = {"failed", "rejected"} +WARNING_STATES = {"canceled"} + + +class OutputMode(Enum): + """Output mode for CLI commands.""" + + RAW = "raw" + TEXT = "text" + JSON = "json" + + +def _strip_markup(text: str) -> str: + """Strip Rich markup for raw output.""" + return re.sub(r"\[/?[^\]]+\]", "", text) + + +class Output: + """Manages output mode and styling. + + Provides a unified interface for outputting text, fields, JSON, and + markdown with automatic mode-aware formatting. + """ + + def __init__(self, mode: OutputMode) -> None: + """Initialize output context. + + Args: + mode: Output mode (raw, text, or json) + """ + self.mode = mode + self._raw_console = Console(highlight=False, markup=False) + + def _print(self, text: str, style: str | None = None) -> None: + """Internal print method that respects mode.""" + if self.mode == OutputMode.RAW: + self._raw_console.print(_strip_markup(text)) + elif self.mode == OutputMode.TEXT and style: + console.print(text, style=style) + else: + console.print(text, markup=self.mode == OutputMode.TEXT) + + def line(self, text: str, style: str | None = None) -> None: + """Print a line of text with optional style.""" + self._print(text, style) + + def field( + self, + name: str, + value: Any, + dim_value: bool = False, + value_style: str | None = None, + ) -> None: + """Print a field as 'Name: value' with formatting.""" + value_str = str(value) if value is not None else "none" + + if self.mode == OutputMode.TEXT: + if value_style: + console.print(f"[bold]{name}:[/bold] [{value_style}]{value_str}[/]") + elif dim_value: + console.print(f"[bold]{name}:[/bold] [dim]{value_str}[/dim]") + else: + console.print(f"[bold]{name}:[/bold] {value_str}") + else: + self._raw_console.print(f"{name}: {_strip_markup(value_str)}") + + def header(self, text: str) -> None: + """Print a section header.""" + if self.mode == OutputMode.TEXT: + console.print(f"\n[bold]{text}[/bold]") + else: + self._raw_console.print(f"\n{text}") + + def subheader(self, text: str) -> None: + """Print a subheader (less prominent than header).""" + if self.mode == OutputMode.TEXT: + console.print(f"[bold cyan]{text}[/bold cyan]") + else: + self._raw_console.print(text) + + def blank(self) -> None: + """Print a blank line.""" + if self.mode == OutputMode.TEXT: + console.print() + else: + self._raw_console.print() + + def state(self, name: str, state: str) -> None: + """Print a state field with appropriate coloring.""" + if self.mode == OutputMode.TEXT: + lower = state.lower() + if lower in SUCCESS_STATES: + style = "green" + elif lower in ERROR_STATES: + style = "red" + elif lower in WARNING_STATES: + style = "yellow" + elif lower in TERMINAL_STATES: + style = "bold" + else: + style = "cyan" + console.print(f"[bold]{name}:[/bold] [{style}]{state}[/{style}]") + else: + self._raw_console.print(f"{name}: {state}") + + def success(self, text: str) -> None: + """Print a success message.""" + self._print(text, "green") + + def error(self, text: str) -> None: + """Print an error message.""" + self._print(text, "red bold") + + def warning(self, text: str) -> None: + """Print a warning message.""" + self._print(text, "yellow") + + def dim(self, text: str) -> None: + """Print dimmed/muted text.""" + self._print(text, "dim") + + def json(self, data: Any) -> None: + """Print JSON data.""" + json_str = json_module.dumps(data, indent=2, default=str) + self._raw_console.print(json_str) + + def markdown(self, text: str) -> None: + """Print markdown content.""" + if self.mode == OutputMode.TEXT: + console.print(Markdown(text)) + else: + self._raw_console.print(text) + + def list_item(self, text: str, bullet: str = "•") -> None: + """Print a list item with bullet.""" + if self.mode == OutputMode.TEXT: + console.print(f" [dim]{bullet}[/dim] {text}") + else: + self._raw_console.print(f" {bullet} {_strip_markup(text)}") + + +_current_context: Output | None = None + + +@contextmanager +def get_output_context( + mode: OutputMode | str, +) -> Generator[Output, None, None]: + global _current_context + + if isinstance(mode, str): + mode = OutputMode(mode) + + context = Output(mode) + _current_context = context + try: + yield context + finally: + _current_context = None diff --git a/src/a2a_handler/common/printing.py b/src/a2a_handler/common/printing.py deleted file mode 100644 index 06a9c43..0000000 --- a/src/a2a_handler/common/printing.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Unified Rich printing configuration for Handler packages.""" - -from typing import Literal - -from rich.markdown import Markdown -from rich.panel import Panel -from rich.json import JSON - -from .logging import console - -BorderStyle = Literal["green", "blue", "yellow", "red", "cyan", "magenta", "dim"] - - -def print_panel( - content: str, - title: str | None = None, - border_style: BorderStyle = "green", - expand: bool = False, - markdown: bool = False, -) -> None: - """Print content in a Rich panel. - - Args: - content: The content to display - title: Optional panel title - border_style: Border color/style - expand: Whether to expand panel to full width - markdown: Whether to render content as markdown - """ - renderable = Markdown(content) if markdown else content - console.print( - Panel(renderable, title=title, border_style=border_style, expand=expand) - ) - - -def print_info(content: str, title: str | None = None) -> None: - """Print an info panel (cyan border).""" - print_panel(content, title=title, border_style="cyan") - - -def print_success(content: str, title: str | None = None) -> None: - """Print a success panel (green border).""" - print_panel(content, title=title, border_style="green") - - -def print_warning(content: str, title: str | None = None) -> None: - """Print a warning panel (yellow border).""" - print_panel(content, title=title, border_style="yellow") - - -def print_error(content: str, title: str | None = None) -> None: - """Print an error panel (red border).""" - print_panel(content, title=title, border_style="red") - - -def print_json(data: str, title: str | None = None) -> None: - """Print JSON in a panel with structural highlighting. - - Args: - data: JSON string to display - title: Optional panel title - """ - json_renderable = JSON(data, highlight=False) - console.print( - Panel(json_renderable, title=title, border_style="green", expand=False) - ) - - -def print_markdown( - content: str, title: str | None = None, border_style: BorderStyle = "green" -) -> None: - """Print markdown content in a panel. - - Args: - content: Markdown content to display - title: Optional panel title - border_style: Border color/style - """ - print_panel(content, title=title, border_style=border_style, markdown=True) diff --git a/src/a2a_handler/components/agent_card.py b/src/a2a_handler/components/agent_card.py deleted file mode 100644 index 1cddb4d..0000000 --- a/src/a2a_handler/components/agent_card.py +++ /dev/null @@ -1,231 +0,0 @@ -import json -import logging -import re -from typing import Any - -from a2a.types import AgentCard -from rich.syntax import Syntax -from textual.app import ComposeResult -from textual.binding import Binding -from textual.containers import Container, VerticalScroll -from textual.widgets import Static, TabbedContent, TabPane, Tabs - -logger = logging.getLogger(__name__) - -TEXTUAL_TO_SYNTAX_THEME: dict[str, str] = { - "gruvbox": "gruvbox-dark", - "nord": "nord", - "tokyo-night": "monokai", - "textual-dark": "monokai", - "textual-light": "default", - "solarized-light": "solarized-light", - "dracula": "dracula", - "catppuccin-mocha": "monokai", - "monokai": "monokai", -} - - -class AgentCardPanel(Container): - """Panel displaying agent card information with tabs.""" - - BINDINGS = [ - Binding("h", "prev_tab", "Prev Tab", show=False), - Binding("l", "next_tab", "Next Tab", show=False), - Binding("left", "prev_tab", "Prev Tab", show=False), - Binding("right", "next_tab", "Next Tab", show=False), - Binding("j", "scroll_down", "Scroll Down", show=False), - Binding("k", "scroll_up", "Scroll Up", show=False), - Binding("down", "scroll_down", "Scroll Down", show=False), - Binding("up", "scroll_up", "Scroll Up", show=False), - ] - - can_focus = True - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._agent_card: AgentCard | None = None - - def compose(self) -> ComposeResult: - with TabbedContent(id="agent-card-tabs"): - with TabPane("Short", id="short-tab"): - yield VerticalScroll( - Static("Not connected", id="agent-short"), - id="short-scroll", - ) - with TabPane("Long", id="long-tab"): - yield VerticalScroll( - Static("", id="agent-long"), - id="long-scroll", - ) - with TabPane("Raw", id="raw-tab"): - yield VerticalScroll( - Static("", id="agent-raw"), - id="raw-scroll", - ) - - def on_mount(self) -> None: - self.border_title = "AGENT CARD" - self.border_subtitle = "READY" - for widget in self.query("TabbedContent, Tabs, Tab, TabPane, VerticalScroll"): - widget.can_focus = False - - def _get_syntax_theme(self) -> str: - """Get the Rich Syntax theme name for the current app theme.""" - return TEXTUAL_TO_SYNTAX_THEME.get(self.app.theme or "", "monokai") - - def _format_key(self, key: str) -> str: - """Convert a key to sentence case.""" - spaced = re.sub(r"([a-z])([A-Z])", r"\1 \2", key) - return spaced.replace("_", " ").capitalize() - - def _is_empty(self, value: Any) -> bool: - """Check if a value is truly empty, including nested structures.""" - if value is None: - return True - if isinstance(value, (str, list, dict)) and not value: - return True - if isinstance(value, dict): - return all(self._is_empty(v) for v in value.values()) - if isinstance(value, list): - return all(self._is_empty(v) for v in value) - return False - - def _format_value(self, value: Any, indent: int = 0) -> str: - """Format a nested value for display.""" - prefix = " " * indent - if isinstance(value, dict): - lines = [] - for k, v in value.items(): - if self._is_empty(v): - continue - formatted_key = self._format_key(k) - if isinstance(v, (list, dict)): - lines.append(f"{prefix}[bold]{formatted_key}[/]") - lines.append(self._format_value(v, indent + 1)) - else: - lines.append(f"{prefix}[bold]{formatted_key}:[/] {v}") - return "\n".join(lines) - if isinstance(value, list): - lines = [] - for item in value: - if self._is_empty(item): - continue - if isinstance(item, dict): - lines.append(self._format_value(item, indent)) - else: - lines.append(f"{prefix}• {item}") - return "\n".join(lines) - return f"{prefix}{value}" - - def _build_short_view(self, card: AgentCard) -> str: - """Build the short view with essential fields only.""" - card_dict = card.model_dump() - lines = [] - - short_fields = [ - "name", - "description", - "version", - "url", - "defaultInputModes", - "defaultOutputModes", - ] - - for key in short_fields: - value = card_dict.get(key) - if self._is_empty(value): - continue - formatted_key = self._format_key(key) - if isinstance(value, (list, dict)): - lines.append(f"[bold]{formatted_key}[/]") - lines.append(self._format_value(value, indent=1)) - else: - lines.append(f"[bold]{formatted_key}:[/] {value}") - - return "\n".join(lines) - - def _build_long_view(self, card: AgentCard) -> str: - """Build the long view with all non-empty fields.""" - card_dict = card.model_dump() - lines = [] - - for key, value in card_dict.items(): - if self._is_empty(value): - continue - formatted_key = self._format_key(key) - if isinstance(value, (list, dict)): - lines.append(f"[bold]{formatted_key}[/]") - lines.append(self._format_value(value, indent=1)) - else: - lines.append(f"[bold]{formatted_key}:[/] {value}") - - return "\n".join(lines) - - def update_card(self, card: AgentCard | None) -> None: - """Update the displayed agent card.""" - self._agent_card = card - - if card is None: - self.query_one("#agent-short", Static).update("Not connected") - self.query_one("#agent-long", Static).update("") - self.query_one("#agent-raw", Static).update("") - self.border_subtitle = "READY" - else: - self.query_one("#agent-short", Static).update(self._build_short_view(card)) - self.query_one("#agent-long", Static).update(self._build_long_view(card)) - - json_str = json.dumps(card.model_dump(), indent=2, default=str) - self.query_one("#agent-raw", Static).update( - Syntax(json_str, "json", theme=self._get_syntax_theme()) - ) - self.border_subtitle = "ACTIVE" - - def refresh_theme(self) -> None: - """Refresh the raw view syntax highlighting for theme changes.""" - if self._agent_card is None: - return - json_str = json.dumps(self._agent_card.model_dump(), indent=2, default=str) - self.query_one("#agent-raw", Static).update( - Syntax(json_str, "json", theme=self._get_syntax_theme()) - ) - - def _get_active_scroll(self) -> VerticalScroll | None: - """Get the currently visible scroll container.""" - tabs = self.query_one("#agent-card-tabs", TabbedContent) - active_tab = tabs.active - - if active_tab == "short-tab": - return self.query_one("#short-scroll", VerticalScroll) - elif active_tab == "long-tab": - return self.query_one("#long-scroll", VerticalScroll) - elif active_tab == "raw-tab": - return self.query_one("#raw-scroll", VerticalScroll) - return None - - def action_prev_tab(self) -> None: - """Switch to the previous tab.""" - try: - tabs = self.query_one("#agent-card-tabs Tabs", Tabs) - tabs.action_previous_tab() - except Exception: - pass - - def action_next_tab(self) -> None: - """Switch to the next tab.""" - try: - tabs = self.query_one("#agent-card-tabs Tabs", Tabs) - tabs.action_next_tab() - except Exception: - pass - - def action_scroll_down(self) -> None: - """Scroll down in the active tab's scroll container.""" - scroll = self._get_active_scroll() - if scroll: - scroll.scroll_down() - - def action_scroll_up(self) -> None: - """Scroll up in the active tab's scroll container.""" - scroll = self._get_active_scroll() - if scroll: - scroll.scroll_up() diff --git a/src/a2a_handler/server.py b/src/a2a_handler/server.py index 6e6e8ee..ea66162 100644 --- a/src/a2a_handler/server.py +++ b/src/a2a_handler/server.py @@ -1,21 +1,46 @@ -"""Handler A2A server agent.""" +"""A2A server agent with streaming and push notification support. + +Provides a local A2A-compatible agent server for testing and development. +""" import os +from collections.abc import Awaitable, Callable import click +import httpx import uvicorn +from a2a.server.apps import A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import ( + BasePushNotificationSender, + InMemoryPushNotificationConfigStore, + InMemoryTaskStore, +) +from a2a.types import AgentCapabilities, AgentCard, AgentSkill from dotenv import load_dotenv -from google.adk.a2a.utils.agent_to_a2a import to_a2a +from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor from google.adk.agents.llm_agent import Agent +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.auth.credential_service.in_memory_credential_service import ( + InMemoryCredentialService, +) +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService from google.adk.models.lite_llm import LiteLlm +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from starlette.applications import Starlette from a2a_handler.common import console, get_logger, setup_logging setup_logging(level="INFO", suppress_libs=["uvicorn", "google"]) -log = get_logger(__name__) +logger = get_logger(__name__) + +DEFAULT_OLLAMA_API_BASE = "http://localhost:11434" +DEFAULT_OLLAMA_MODEL = "qwen3" +DEFAULT_HTTP_TIMEOUT_SECONDS = 30 -def create_agent() -> Agent: +def create_llm_agent() -> Agent: """Create and configure the A2A test agent using LiteLLM with Ollama. Returns: @@ -23,44 +48,167 @@ def create_agent() -> Agent: """ load_dotenv() - ollama_base = os.getenv("OLLAMA_API_BASE", "http://localhost:11434") - ollama_model = os.getenv("OLLAMA_MODEL", "qwen3") + ollama_api_base = os.getenv("OLLAMA_API_BASE", DEFAULT_OLLAMA_API_BASE) + ollama_model_name = os.getenv("OLLAMA_MODEL", DEFAULT_OLLAMA_MODEL) - log.info( + logger.info( "Creating agent with model: [highlight]%s[/highlight] at [url]%s[/url]", - ollama_model, - ollama_base, + ollama_model_name, + ollama_api_base, ) - model = LiteLlm( - model=f"ollama_chat/{ollama_model}", - api_base=ollama_base, + language_model = LiteLlm( + model=f"ollama_chat/{ollama_model_name}", + api_base=ollama_api_base, reasoning_effort="none", ) agent = Agent( name="Handler", - model=model, - description="Handler assistant", - instruction="""You are Handler, the resident helpful agent for the Handler application. -You are an expert on the Handler toolkit, which is a terminal-based system for communicating with and testing Agent-to-Agent (A2A) protocol agents. -You know that the Handler project consists of: -1. A TUI (Text User Interface) for interactive agent management -2. A CLI (Command Line Interface) for scripting and quick interactions -3. A Client library (packages/client) that implements the A2A protocol -4. A server agent (packages/server) - which is what you are currently running on! - -You should be helpful, friendly, and eager to explain how Handler works. -If asked about installation, usage, or development, provide clear, concise guidance based on the project structure. -You are proud to be an A2A server agent.""", + model=language_model, + description="Handler's Agent", + instruction="""You are Handler's Agent, the built-in assistant for the Handler application. + +Handler is an A2A (Agent-to-Agent) protocol client published on PyPI as `a2a-handler`. It provides tools for developers to communicate with, test, and debug A2A-compatible agents. + +Handler's architecture consists of: +1. **TUI** - An interactive terminal interface (Textual-based) for managing agent connections, sending messages, and viewing streaming responses +2. **CLI** - A rich-click powered command-line interface for scripting and automation with commands for: + - `message send/stream` - Send messages to agents with optional streaming + - `task get/cancel/resubscribe` - Manage A2A tasks + - `card get/validate` - Retrieve and validate agent cards + - `session list/show/clear` - Manage conversation sessions + - `server agent/push` - Run local servers (including this one!) +3. **A2AService** - A unified service layer wrapping the a2a-sdk for protocol operations +4. **Server Agent** - A local A2A-compatible agent (you!) for testing, built with Google ADK and LiteLLM/Ollama + +Handler supports streaming responses, push notifications, session persistence, and both JSON and formatted text output. + +You are running as Handler's built-in server agent, useful for testing A2A integrations locally. Be helpful, concise, and knowledgeable about both Handler and the A2A protocol.""", ) - log.info( + logger.info( "[success]Agent created successfully:[/success] [agent]%s[/agent]", agent.name ) return agent +def build_agent_card(agent: Agent, host: str, port: int) -> AgentCard: + """Build an AgentCard with streaming and push notification capabilities. + + Args: + agent: The ADK agent + host: Host address for the RPC URL + port: Port number for the RPC URL + + Returns: + Configured AgentCard with capabilities enabled + """ + agent_capabilities = AgentCapabilities( + streaming=True, + push_notifications=True, + ) + + agent_skill = AgentSkill( + id="handler_assistant", + name="Handler Assistant", + description="Answers questions about the Handler A2A toolkit and helps with usage", + tags=["a2a", "handler", "help"], + examples=["What is Handler?", "How do I use the CLI?", "Tell me about A2A"], + ) + + rpc_endpoint_url = f"http://{host}:{port}/" + + logger.debug("Building agent card with RPC URL: %s", rpc_endpoint_url) + + return AgentCard( + name=agent.name, + description=agent.description or "Handler A2A agent", + url=rpc_endpoint_url, + version="1.0.0", + capabilities=agent_capabilities, + skills=[agent_skill], + default_input_modes=["text/plain"], + default_output_modes=["text/plain"], + ) + + +def create_runner_factory(agent: Agent) -> Callable[[], Awaitable[Runner]]: + """Create a factory function that builds a Runner for the agent. + + Args: + agent: The ADK agent to wrap + + Returns: + A callable that creates a Runner instance + """ + + async def create_runner() -> Runner: + return Runner( + app_name=agent.name or "handler_agent", + agent=agent, + artifact_service=InMemoryArtifactService(), + session_service=InMemorySessionService(), + memory_service=InMemoryMemoryService(), + credential_service=InMemoryCredentialService(), + ) + + return create_runner + + +def create_a2a_application(agent: Agent, agent_card: AgentCard) -> Starlette: + """Create a Starlette A2A application with full push notification support. + + This is a custom implementation that replaces google-adk's to_a2a() to add + push notification support. The to_a2a() function doesn't pass push_config_store + or push_sender to DefaultRequestHandler, causing push notification operations + to fail with "UnsupportedOperationError". + + Args: + agent: The ADK agent + agent_card: Pre-configured agent card + + Returns: + Configured Starlette application + """ + task_store = InMemoryTaskStore() + push_notification_config_store = InMemoryPushNotificationConfigStore() + http_client = httpx.AsyncClient(timeout=DEFAULT_HTTP_TIMEOUT_SECONDS) + push_notification_sender = BasePushNotificationSender( + http_client, push_notification_config_store + ) + + agent_executor = A2aAgentExecutor( + runner=create_runner_factory(agent), + ) + + request_handler = DefaultRequestHandler( + agent_executor=agent_executor, + task_store=task_store, + push_config_store=push_notification_config_store, + push_sender=push_notification_sender, + ) + + application = Starlette() + + async def setup_a2a_routes() -> None: + a2a_starlette_app = A2AStarletteApplication( + agent_card=agent_card, + http_handler=request_handler, + ) + a2a_starlette_app.add_routes_to_app(application) + logger.info("A2A routes configured with push notification support") + + async def cleanup_http_client() -> None: + await http_client.aclose() + logger.info("HTTP client closed") + + application.add_event_handler("startup", setup_a2a_routes) + application.add_event_handler("shutdown", cleanup_http_client) + + return application + + def run_server(host: str, port: int) -> None: """Start the A2A server agent. @@ -71,10 +219,26 @@ def run_server(host: str, port: int) -> None: console.print( f"\n[bold]Starting Handler server on [url]{host}:{port}[/url][/bold]\n" ) - log.info("Initializing A2A server...") - agent = create_agent() - a2a_app = to_a2a(agent, host=host, port=port) - uvicorn.run(a2a_app, host=host, port=port) + logger.info("Initializing A2A server with push notification support...") + + agent = create_llm_agent() + agent_card = build_agent_card(agent, host, port) + + streaming_enabled = ( + agent_card.capabilities.streaming if agent_card.capabilities else False + ) + push_notifications_enabled = ( + agent_card.capabilities.push_notifications if agent_card.capabilities else False + ) + + logger.info( + "Agent card capabilities: streaming=%s, push_notifications=%s", + streaming_enabled, + push_notifications_enabled, + ) + + a2a_application = create_a2a_application(agent, agent_card) + uvicorn.run(a2a_application, host=host, port=port) @click.command() diff --git a/src/a2a_handler/service.py b/src/a2a_handler/service.py new file mode 100644 index 0000000..3ba0a4f --- /dev/null +++ b/src/a2a_handler/service.py @@ -0,0 +1,518 @@ +"""A2A protocol service layer. + +Provides a unified interface for A2A operations, shared between the CLI and TUI. +""" + +import uuid +from dataclasses import dataclass, field +from typing import Any, AsyncIterator + +import httpx +from a2a.client import A2ACardResolver, Client, ClientConfig, ClientFactory +from a2a.types import ( + AgentCard, + Message, + Part, + PushNotificationConfig, + Role, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskState, + TaskStatusUpdateEvent, + TextPart, + TransportProtocol, +) + +from a2a_handler.common import get_logger + +logger = get_logger(__name__) + +TERMINAL_TASK_STATES = { + TaskState.completed, + TaskState.canceled, + TaskState.failed, + TaskState.rejected, +} + + +@dataclass +class SendResult: + """Result from sending a message to an agent.""" + + task: Task | None = None + message: Message | None = None + context_id: str | None = None + task_id: str | None = None + state: TaskState | None = None + text: str = "" + raw: dict[str, Any] = field(default_factory=dict) + + @property + def is_complete(self) -> bool: + """Check if the task reached a terminal state.""" + return self.state in TERMINAL_TASK_STATES if self.state else False + + @property + def needs_input(self) -> bool: + """Check if the task is waiting for user input.""" + return self.state == TaskState.input_required if self.state else False + + +@dataclass +class StreamEvent: + """A single event from a streaming response.""" + + event_type: str + task: Task | None = None + message: Message | None = None + status: TaskStatusUpdateEvent | None = None + artifact: TaskArtifactUpdateEvent | None = None + context_id: str | None = None + task_id: str | None = None + state: TaskState | None = None + text: str = "" + + +@dataclass +class TaskResult: + """Result from a task operation (get/cancel).""" + + task: Task + task_id: str + state: TaskState + context_id: str | None = None + text: str = "" + raw: dict[str, Any] = field(default_factory=dict) + + +def extract_text_from_message_parts(message_parts: list[Part] | None) -> str: + """Extract text content from message parts.""" + if not message_parts: + return "" + + extracted_texts = [] + for part in message_parts: + if isinstance(part.root, TextPart): + extracted_texts.append(part.root.text) + + return "\n".join(text for text in extracted_texts if text) + + +def extract_text_from_task(task: Task) -> str: + """Extract text from task artifacts, falling back to history if no artifacts.""" + extracted_texts = [] + + if task.artifacts: + for artifact in task.artifacts: + if artifact.parts: + extracted_texts.append(extract_text_from_message_parts(artifact.parts)) + + # Only check history if no artifacts found (avoids duplication) + if not extracted_texts and task.history: + for message in task.history: + if message.role == Role.agent and message.parts: + extracted_texts.append(extract_text_from_message_parts(message.parts)) + + return "\n".join(text for text in extracted_texts if text) + + +class A2AService: + """High-level service for A2A protocol operations. + + Wraps the a2a-sdk Client and provides a simplified interface + for common operations. Designed to be shared between CLI and TUI. + """ + + def __init__( + self, + http_client: httpx.AsyncClient, + agent_url: str, + enable_streaming: bool = True, + push_notification_url: str | None = None, + push_notification_token: str | None = None, + ) -> None: + """Initialize the A2A service. + + Args: + http_client: Async HTTP client to use for requests + agent_url: Base URL of the A2A agent + enable_streaming: Whether to prefer streaming when available + push_notification_url: Optional webhook URL for push notifications + push_notification_token: Optional token for push notification auth + """ + self.http_client = http_client + self.agent_url = agent_url + self.enable_streaming = enable_streaming + self.push_notification_url = push_notification_url + self.push_notification_token = push_notification_token + self._cached_client: Client | None = None + self._cached_agent_card: AgentCard | None = None + + async def get_card(self) -> AgentCard: + """Fetch and cache the agent card. + + Returns: + The agent's card with metadata and capabilities + """ + if self._cached_agent_card is None: + logger.info("Fetching agent card from %s", self.agent_url) + card_resolver = A2ACardResolver(self.http_client, self.agent_url) + self._cached_agent_card = await card_resolver.get_agent_card() + logger.info("Connected to agent: %s", self._cached_agent_card.name) + return self._cached_agent_card + + async def _get_or_create_client(self) -> Client: + """Get or create the A2A client. + + Returns: + Configured A2A client instance + """ + if self._cached_client is None: + agent_card = await self.get_card() + + push_notification_configs: list[PushNotificationConfig] = [] + if self.push_notification_url: + push_notification_configs.append( + PushNotificationConfig( + url=self.push_notification_url, + token=self.push_notification_token, + ) + ) + logger.info( + "Push notification configured: %s", self.push_notification_url + ) + + client_config = ClientConfig( + httpx_client=self.http_client, + supported_transports=[TransportProtocol.jsonrpc], + streaming=self.enable_streaming, + push_notification_configs=push_notification_configs, + ) + + client_factory = ClientFactory(client_config) + self._cached_client = client_factory.create(agent_card) + logger.debug("Created A2A client for %s", agent_card.name) + + return self._cached_client + + @property + def supports_streaming(self) -> bool: + """Check if the agent supports streaming.""" + if self._cached_agent_card and self._cached_agent_card.capabilities: + return bool(self._cached_agent_card.capabilities.streaming) + return False + + @property + def supports_push_notifications(self) -> bool: + """Check if the agent supports push notifications.""" + if self._cached_agent_card and self._cached_agent_card.capabilities: + return bool(self._cached_agent_card.capabilities.push_notifications) + return False + + def _build_user_message( + self, + message_text: str, + context_id: str | None = None, + task_id: str | None = None, + ) -> Message: + """Build a user message. + + Args: + message_text: Message content + context_id: Optional context ID for conversation continuity + task_id: Optional task ID to continue + + Returns: + Properly formatted Message object + """ + return Message( + message_id=str(uuid.uuid4()), + role=Role.user, + parts=[Part(root=TextPart(text=message_text))], + context_id=context_id, + task_id=task_id, + ) + + async def send( + self, + message_text: str, + context_id: str | None = None, + task_id: str | None = None, + ) -> SendResult: + """Send a message to the agent and wait for completion. + + This method collects all streaming events and returns the final result. + + Args: + message_text: Message to send + context_id: Optional context ID for conversation continuity + task_id: Optional task ID to continue + + Returns: + SendResult with task state, extracted text, and IDs + """ + client = await self._get_or_create_client() + user_message = self._build_user_message(message_text, context_id, task_id) + + truncated_message = ( + message_text[:50] if len(message_text) > 50 else message_text + ) + logger.info("Sending message: %s", truncated_message) + + result = SendResult() + last_received_task: Task | None = None + + async for event in client.send_message(user_message): + if isinstance(event, Message): + result.message = event + result.context_id = event.context_id + result.task_id = event.task_id + result.text = extract_text_from_message_parts(event.parts) + logger.debug("Received message response") + elif isinstance(event, tuple): + received_task, task_update = event + last_received_task = received_task + result.task = received_task + result.task_id = received_task.id + result.context_id = received_task.context_id + if received_task.status: + result.state = received_task.status.state + logger.debug( + "Received task update: %s", + received_task.status.state if received_task.status else "unknown", + ) + + if last_received_task: + result.text = extract_text_from_task(last_received_task) + result.raw = ( + last_received_task.model_dump() + if hasattr(last_received_task, "model_dump") + else {} + ) + elif result.message: + result.raw = ( + result.message.model_dump() + if hasattr(result.message, "model_dump") + else {} + ) + + logger.info("Send complete: task_id=%s, state=%s", result.task_id, result.state) + return result + + async def stream( + self, + message_text: str, + context_id: str | None = None, + task_id: str | None = None, + ) -> AsyncIterator[StreamEvent]: + """Send a message and stream responses as they arrive. + + Args: + message_text: Message to send + context_id: Optional context ID for conversation continuity + task_id: Optional task ID to continue + + Yields: + StreamEvent objects as they are received + """ + client = await self._get_or_create_client() + user_message = self._build_user_message(message_text, context_id, task_id) + + truncated_message = ( + message_text[:50] if len(message_text) > 50 else message_text + ) + logger.info("Streaming message: %s", truncated_message) + + async for event in client.send_message(user_message): + if isinstance(event, Message): + yield StreamEvent( + event_type="message", + message=event, + context_id=event.context_id, + task_id=event.task_id, + text=extract_text_from_message_parts(event.parts), + ) + elif isinstance(event, tuple): + received_task, task_update = event + if isinstance(task_update, TaskStatusUpdateEvent): + status_message_text = "" + if task_update.status and task_update.status.message: + status_message_text = str(task_update.status.message) + yield StreamEvent( + event_type="status", + task=received_task, + status=task_update, + context_id=received_task.context_id, + task_id=received_task.id, + state=task_update.status.state if task_update.status else None, + text=status_message_text, + ) + elif isinstance(task_update, TaskArtifactUpdateEvent): + artifact_text = "" + if task_update.artifact and task_update.artifact.parts: + artifact_text = extract_text_from_message_parts( + task_update.artifact.parts + ) + yield StreamEvent( + event_type="artifact", + task=received_task, + artifact=task_update, + context_id=received_task.context_id, + task_id=received_task.id, + state=( + received_task.status.state if received_task.status else None + ), + text=artifact_text, + ) + else: + yield StreamEvent( + event_type="task", + task=received_task, + context_id=received_task.context_id, + task_id=received_task.id, + state=( + received_task.status.state if received_task.status else None + ), + text=extract_text_from_task(received_task), + ) + + async def get_task( + self, + task_id: str, + history_length: int | None = None, + ) -> TaskResult: + """Get the current state of a task. + + Args: + task_id: ID of the task to retrieve + history_length: Optional number of history messages to include + + Returns: + TaskResult with task state and details + """ + client = await self._get_or_create_client() + + query_params = TaskQueryParams(id=task_id, history_length=history_length) + logger.info("Getting task: %s", task_id) + + task = await client.get_task(query_params) + + return TaskResult( + task=task, + task_id=task.id, + state=task.status.state if task.status else TaskState.unknown, + context_id=task.context_id, + text=extract_text_from_task(task), + raw=task.model_dump() if hasattr(task, "model_dump") else {}, + ) + + async def cancel_task(self, task_id: str) -> TaskResult: + """Cancel a running task. + + Args: + task_id: ID of the task to cancel + + Returns: + TaskResult with updated task state + """ + client = await self._get_or_create_client() + + task_id_params = TaskIdParams(id=task_id) + logger.info("Canceling task: %s", task_id) + + task = await client.cancel_task(task_id_params) + + return TaskResult( + task=task, + task_id=task.id, + state=task.status.state if task.status else TaskState.unknown, + context_id=task.context_id, + text=extract_text_from_task(task), + raw=task.model_dump() if hasattr(task, "model_dump") else {}, + ) + + async def resubscribe(self, task_id: str) -> AsyncIterator[StreamEvent]: + """Resubscribe to a task's event stream. + + Args: + task_id: ID of the task to resubscribe to + + Yields: + StreamEvent objects as they are received + """ + client = await self._get_or_create_client() + + task_id_params = TaskIdParams(id=task_id) + logger.info("Resubscribing to task: %s", task_id) + + async for event in client.resubscribe(task_id_params): + received_task, task_update = event + if isinstance(task_update, TaskStatusUpdateEvent): + yield StreamEvent( + event_type="status", + task=received_task, + status=task_update, + context_id=received_task.context_id, + task_id=received_task.id, + state=task_update.status.state if task_update.status else None, + ) + elif isinstance(task_update, TaskArtifactUpdateEvent): + artifact_text = "" + if task_update.artifact and task_update.artifact.parts: + artifact_text = extract_text_from_message_parts( + task_update.artifact.parts + ) + yield StreamEvent( + event_type="artifact", + task=received_task, + artifact=task_update, + context_id=received_task.context_id, + task_id=received_task.id, + state=( + received_task.status.state if received_task.status else None + ), + text=artifact_text, + ) + else: + yield StreamEvent( + event_type="task", + task=received_task, + context_id=received_task.context_id, + task_id=received_task.id, + state=( + received_task.status.state if received_task.status else None + ), + text=extract_text_from_task(received_task), + ) + + async def set_push_config( + self, + task_id: str, + webhook_url: str, + authentication_token: str | None = None, + ) -> TaskPushNotificationConfig: + """Set push notification configuration for a task. + + Args: + task_id: ID of the task + webhook_url: Webhook URL to receive notifications + authentication_token: Optional authentication token + + Returns: + The created push notification configuration + """ + client = await self._get_or_create_client() + + push_config = TaskPushNotificationConfig( + task_id=task_id, + push_notification_config=PushNotificationConfig( + url=webhook_url, + token=authentication_token, + ), + ) + logger.info("Setting push config for task %s: %s", task_id, webhook_url) + + return await client.set_task_callback(push_config) diff --git a/src/a2a_handler/session.py b/src/a2a_handler/session.py new file mode 100644 index 0000000..b8eaaa4 --- /dev/null +++ b/src/a2a_handler/session.py @@ -0,0 +1,181 @@ +"""Session state management for the Handler CLI. + +Persists context_id and task_id across CLI invocations for conversation continuity. +""" + +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from a2a_handler.common import get_logger + +logger = get_logger(__name__) + +DEFAULT_SESSION_DIRECTORY = Path.home() / ".handler" +SESSION_FILENAME = "sessions.json" + + +@dataclass +class AgentSession: + """Session state for a single agent.""" + + agent_url: str + context_id: str | None = None + task_id: str | None = None + + def update( + self, + context_id: str | None = None, + task_id: str | None = None, + ) -> None: + """Update session with new values (only if provided).""" + if context_id is not None: + self.context_id = context_id + if task_id is not None: + self.task_id = task_id + + +@dataclass +class SessionStore: + """Persistent store for agent sessions.""" + + sessions: dict[str, AgentSession] = field(default_factory=dict) + session_directory: Path = field(default_factory=lambda: DEFAULT_SESSION_DIRECTORY) + + @property + def session_file_path(self) -> Path: + """Path to the session file.""" + return self.session_directory / SESSION_FILENAME + + def _ensure_directory_exists(self) -> None: + """Ensure the session directory exists.""" + self.session_directory.mkdir(parents=True, exist_ok=True) + + def load(self) -> None: + """Load sessions from disk.""" + if not self.session_file_path.exists(): + logger.debug("No session file found at %s", self.session_file_path) + return + + try: + with open(self.session_file_path) as session_file: + session_data = json.load(session_file) + + for agent_url, agent_session_data in session_data.items(): + self.sessions[agent_url] = AgentSession( + agent_url=agent_url, + context_id=agent_session_data.get("context_id"), + task_id=agent_session_data.get("task_id"), + ) + + logger.debug( + "Loaded %d sessions from %s", + len(self.sessions), + self.session_file_path, + ) + + except json.JSONDecodeError as error: + logger.warning("Failed to parse session file: %s", error) + except OSError as error: + logger.warning("Failed to read session file: %s", error) + + def save(self) -> None: + """Save sessions to disk.""" + self._ensure_directory_exists() + + session_data: dict[str, Any] = {} + for agent_url, agent_session in self.sessions.items(): + session_data[agent_url] = { + "context_id": agent_session.context_id, + "task_id": agent_session.task_id, + } + + try: + with open(self.session_file_path, "w") as session_file: + json.dump(session_data, session_file, indent=2) + logger.debug( + "Saved %d sessions to %s", + len(self.sessions), + self.session_file_path, + ) + except OSError as error: + logger.warning("Failed to write session file: %s", error) + + def get(self, agent_url: str) -> AgentSession: + """Get or create a session for an agent URL.""" + if agent_url not in self.sessions: + self.sessions[agent_url] = AgentSession(agent_url=agent_url) + logger.debug("Created new session for %s", agent_url) + return self.sessions[agent_url] + + def update( + self, + agent_url: str, + context_id: str | None = None, + task_id: str | None = None, + ) -> AgentSession: + """Update session for an agent and save.""" + agent_session = self.get(agent_url) + agent_session.update(context_id, task_id) + self.save() + logger.debug( + "Updated session for %s: context_id=%s, task_id=%s", + agent_url, + context_id, + task_id, + ) + return agent_session + + def clear(self, agent_url: str | None = None) -> None: + """Clear session(s). + + Args: + agent_url: If provided, clear only that agent's session. + Otherwise, clear all sessions. + """ + if agent_url: + if agent_url in self.sessions: + del self.sessions[agent_url] + logger.info("Cleared session for %s", agent_url) + else: + session_count = len(self.sessions) + self.sessions.clear() + logger.info("Cleared all %d sessions", session_count) + self.save() + + def list_all(self) -> list[AgentSession]: + """List all sessions.""" + return list(self.sessions.values()) + + +_global_session_store: SessionStore | None = None + + +def get_session_store() -> SessionStore: + """Get the global session store (singleton).""" + global _global_session_store + if _global_session_store is None: + _global_session_store = SessionStore() + _global_session_store.load() + logger.debug("Initialized global session store") + return _global_session_store + + +def get_session(agent_url: str) -> AgentSession: + """Get session for an agent URL.""" + return get_session_store().get(agent_url) + + +def update_session( + agent_url: str, + context_id: str | None = None, + task_id: str | None = None, +) -> AgentSession: + """Update and persist session for an agent.""" + return get_session_store().update(agent_url, context_id, task_id) + + +def clear_session(agent_url: str | None = None) -> None: + """Clear session(s).""" + get_session_store().clear(agent_url) diff --git a/src/a2a_handler/tui.py b/src/a2a_handler/tui.py deleted file mode 100644 index 0b34b5e..0000000 --- a/src/a2a_handler/tui.py +++ /dev/null @@ -1,240 +0,0 @@ -import logging -import uuid -from typing import Any - -import httpx -from a2a.types import AgentCard -from importlib.metadata import version - -__version__ = version("a2a-handler") -from a2a_handler.client import ( - build_http_client, - fetch_agent_card, - send_message_to_agent, -) -from textual import on -from textual.app import App, ComposeResult -from textual.binding import Binding -from textual.containers import Container, Vertical -from textual.logging import TextualHandler -from textual.widgets import Button, Input - -from a2a_handler.components import ( - AgentCardPanel, - ContactPanel, - Footer, - InputPanel, - MessagesPanel, -) - -logging.basicConfig( - level="NOTSET", - handlers=[TextualHandler()], -) -logger = logging.getLogger(__name__) - - -class HandlerTUI(App[Any]): - """Handler - A2A Agent Management Interface.""" - - CSS_PATH = "tui.tcss" - - BINDINGS = [ - Binding("ctrl+q", "quit", "Quit", show=True), - Binding("ctrl+c", "clear_chat", "Clear", show=True), - Binding("ctrl+p", "command_palette", "Palette", show=True), - ] - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.agent_card: AgentCard | None = None - self.http_client: httpx.AsyncClient | None = None - self.context_id: str | None = None - self.agent_url: str | None = None - - def compose(self) -> ComposeResult: - with Container(id="root-container"): - with Vertical(id="left-pane"): - yield ContactPanel(id="contact-container", classes="panel") - yield AgentCardPanel(id="agent-card-container", classes="panel") - - with Vertical(id="right-pane"): - yield MessagesPanel(id="messages-container", classes="panel") - yield InputPanel(id="input-container", classes="panel") - - yield Footer(id="footer") - - async def on_mount(self) -> None: - logger.info("TUI application starting") - self.http_client = build_http_client() - self.theme = "gruvbox" - - root = self.query_one("#root-container", Container) - root.border_title = f"Handler v{__version__} [red]●[/red] Disconnected" - - messages = self.query_one("#messages-container", MessagesPanel) - messages.add_system_message("Welcome! Connect to an agent to start chatting.") - - def watch_theme(self, theme: str) -> None: - """Called when the app theme changes.""" - logger.debug("Theme changed to: %s", theme) - self.query_one("#agent-card-container", AgentCardPanel).refresh_theme() - - @on(Button.Pressed, "#footer-quit") - async def action_quit_app(self) -> None: - await self.action_quit() - - @on(Button.Pressed, "#footer-clear") - async def action_clear_chat_footer(self) -> None: - await self.action_clear_chat() - - @on(Button.Pressed, "#footer-palette") - def action_open_command_palette(self) -> None: - self.action_command_palette() - - async def _connect(self, agent_url: str) -> AgentCard: - if not self.http_client: - raise RuntimeError("HTTP client not initialized") - - logger.info("Connecting to agent at %s", agent_url) - return await fetch_agent_card(agent_url, self.http_client) - - def _update_ui_connected(self, card: AgentCard) -> None: - root = self.query_one("#root-container", Container) - root.border_title = ( - f"Handler v{__version__} [green]●[/green] Connected: {card.name}" - ) - - self.query_one("#agent-card-container", AgentCardPanel).update_card(card) - self.query_one("#contact-container", ContactPanel).set_connected(True) - self.query_one("#messages-container", MessagesPanel).update_message_count() - - def _update_ui_disconnected(self) -> None: - root = self.query_one("#root-container", Container) - root.border_title = f"Handler v{__version__} [red]●[/red] Disconnected" - - self.query_one("#agent-card-container", AgentCardPanel).update_card(None) - self.query_one("#contact-container", ContactPanel).set_connected(False) - - @on(Button.Pressed, "#connect-btn") - async def connect_to_agent(self) -> None: - contact = self.query_one("#contact-container", ContactPanel) - agent_url = contact.get_url() - - if not agent_url: - logger.warning("Connect attempted with empty URL") - messages = self.query_one("#messages-container", MessagesPanel) - messages.add_system_message("✗ Please enter an agent URL") - return - - messages = self.query_one("#messages-container", MessagesPanel) - messages.add_system_message(f"Connecting to {agent_url}...") - - try: - card = await self._connect(agent_url) - - self.agent_card = card - self.agent_url = agent_url - self.context_id = str(uuid.uuid4()) - - logger.info("Successfully connected to %s", card.name) - - self._update_ui_connected(card) - messages.add_system_message(f"✓ Connected to {card.name}") - self.query_one("#agent-card-container", AgentCardPanel).focus() - - except Exception as e: - logger.error("Connection failed: %s", e, exc_info=True) - messages.add_system_message(f"✗ Connection failed: {str(e)}") - - @on(Button.Pressed, "#disconnect-btn") - def disconnect_from_agent(self) -> None: - logger.info("Disconnecting from %s", self.agent_url) - self.agent_card = None - self.agent_url = None - - messages = self.query_one("#messages-container", MessagesPanel) - messages.add_system_message("Disconnected") - - self._update_ui_disconnected() - - @on(Input.Submitted, "#message-input") - async def send_on_enter(self) -> None: - if self.agent_url: - await self._send_message() - else: - messages = self.query_one("#messages-container", MessagesPanel) - messages.add_system_message("✗ Not connected to an agent") - - @on(Button.Pressed, "#send-btn") - async def send_button_pressed(self) -> None: - if self.agent_url: - await self._send_message() - else: - messages = self.query_one("#messages-container", MessagesPanel) - messages.add_system_message("✗ Not connected to an agent") - - async def _handle_agent_response(self, response_data: dict[str, Any]) -> str: - if not response_data: - return "Error: No result in response" - - texts = [] - if "parts" in response_data: - texts.extend(p.get("text", "") for p in response_data["parts"]) - - for artifact in response_data.get("artifacts", []): - texts.extend(p.get("text", "") for p in artifact.get("parts", [])) - - return "\n".join(t for t in texts if t) or "No text in response" - - async def _send_message(self) -> None: - if not self.agent_url or not self.http_client: - logger.warning("Attempted to send message without connection") - messages = self.query_one("#messages-container", MessagesPanel) - messages.add_system_message("✗ Not connected to an agent") - return - - input_panel = self.query_one("#input-container", InputPanel) - message_text = input_panel.get_message() - - if not message_text: - return - - messages = self.query_one("#messages-container", MessagesPanel) - messages.add_message("user", message_text) - - try: - logger.info("Sending message: %s", message_text[:50]) - - response_data = await send_message_to_agent( - self.agent_url, - message_text, - self.http_client, - context_id=self.context_id, - ) - - response_text = await self._handle_agent_response(response_data) - messages.add_message("agent", response_text) - - except Exception as e: - logger.error("Error sending message: %s", e, exc_info=True) - messages.add_system_message(f"✗ Error: {str(e)}") - - async def action_clear_chat(self) -> None: - messages = self.query_one("#messages-container", MessagesPanel) - await messages.clear() - - async def on_unmount(self) -> None: - logger.info("Shutting down TUI application") - if self.http_client: - await self.http_client.aclose() - - -def main() -> None: - """Entry point for the TUI application.""" - app = HandlerTUI() - app.run() - - -if __name__ == "__main__": - main() diff --git a/src/a2a_handler/tui/__init__.py b/src/a2a_handler/tui/__init__.py new file mode 100644 index 0000000..d16e563 --- /dev/null +++ b/src/a2a_handler/tui/__init__.py @@ -0,0 +1,8 @@ +"""Handler TUI application. + +Provides an interactive terminal interface for communicating with A2A agents. +""" + +from a2a_handler.tui.app import HandlerTUI, main + +__all__ = ["HandlerTUI", "main"] diff --git a/src/a2a_handler/tui/app.py b/src/a2a_handler/tui/app.py new file mode 100644 index 0000000..98587cb --- /dev/null +++ b/src/a2a_handler/tui/app.py @@ -0,0 +1,255 @@ +"""Main TUI application for Handler. + +Provides the Textual-based terminal interface for agent interaction. +""" + +import logging +import uuid +from importlib.metadata import version +from typing import Any + +import httpx +from a2a.types import AgentCard +from textual import on +from textual.app import App, ComposeResult +from textual.binding import Binding +from textual.containers import Container, Vertical +from textual.logging import TextualHandler +from textual.widgets import Button, Input + +from a2a_handler.service import A2AService +from a2a_handler.tui.components import ( + AgentCardPanel, + ContactPanel, + Footer, + InputPanel, + MessagesPanel, +) + +__version__ = version("a2a-handler") + +logging.basicConfig( + level="NOTSET", + handlers=[TextualHandler()], +) +logger = logging.getLogger(__name__) + +DEFAULT_HTTP_TIMEOUT_SECONDS = 120 + + +def build_http_client( + timeout_seconds: int = DEFAULT_HTTP_TIMEOUT_SECONDS, +) -> httpx.AsyncClient: + """Build an HTTP client with the specified timeout.""" + return httpx.AsyncClient(timeout=timeout_seconds) + + +class HandlerTUI(App[Any]): + """Handler - A2A Agent Management Interface.""" + + CSS_PATH = "app.tcss" + + BINDINGS = [ + Binding("ctrl+q", "quit", "Quit", show=True), + Binding("ctrl+c", "clear_chat", "Clear", show=True), + Binding("ctrl+p", "command_palette", "Palette", show=True), + ] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.current_agent_card: AgentCard | None = None + self.http_client: httpx.AsyncClient | None = None + self.current_context_id: str | None = None + self.current_agent_url: str | None = None + self._agent_service: A2AService | None = None + + def compose(self) -> ComposeResult: + with Container(id="root-container"): + with Vertical(id="left-pane"): + yield ContactPanel(id="contact-container", classes="panel") + yield AgentCardPanel(id="agent-card-container", classes="panel") + + with Vertical(id="right-pane"): + yield MessagesPanel(id="messages-container", classes="panel") + yield InputPanel(id="input-container", classes="panel") + + yield Footer(id="footer") + + async def on_mount(self) -> None: + logger.info("TUI application starting") + self.http_client = build_http_client() + self.theme = "gruvbox" + + root_container = self.query_one("#root-container", Container) + root_container.border_title = ( + f"Handler v{__version__} [red]●[/red] Disconnected" + ) + + messages_panel = self.query_one("#messages-container", MessagesPanel) + messages_panel.add_system_message( + "Welcome! Connect to an agent to start chatting." + ) + + def watch_theme(self, new_theme: str) -> None: + """Called when the app theme changes.""" + logger.debug("Theme changed to: %s", new_theme) + agent_card_panel = self.query_one("#agent-card-container", AgentCardPanel) + agent_card_panel.refresh_theme() + + @on(Button.Pressed, "#footer-quit") + async def handle_quit_button(self) -> None: + await self.action_quit() + + @on(Button.Pressed, "#footer-clear") + async def handle_clear_button(self) -> None: + await self.action_clear_chat() + + @on(Button.Pressed, "#footer-palette") + def handle_palette_button(self) -> None: + self.action_command_palette() + + async def _connect_to_agent(self, agent_url: str) -> AgentCard: + if not self.http_client: + raise RuntimeError("HTTP client not initialized") + + logger.info("Connecting to agent at %s", agent_url) + self._agent_service = A2AService(self.http_client, agent_url) + return await self._agent_service.get_card() + + def _update_ui_for_connected_state(self, agent_card: AgentCard) -> None: + root_container = self.query_one("#root-container", Container) + root_container.border_title = ( + f"Handler v{__version__} [green]●[/green] Connected: {agent_card.name}" + ) + + agent_card_panel = self.query_one("#agent-card-container", AgentCardPanel) + agent_card_panel.update_card(agent_card) + + contact_panel = self.query_one("#contact-container", ContactPanel) + contact_panel.set_connected(True) + + messages_panel = self.query_one("#messages-container", MessagesPanel) + messages_panel.update_message_count() + + def _update_ui_for_disconnected_state(self) -> None: + root_container = self.query_one("#root-container", Container) + root_container.border_title = ( + f"Handler v{__version__} [red]●[/red] Disconnected" + ) + + agent_card_panel = self.query_one("#agent-card-container", AgentCardPanel) + agent_card_panel.update_card(None) + + contact_panel = self.query_one("#contact-container", ContactPanel) + contact_panel.set_connected(False) + + @on(Button.Pressed, "#connect-btn") + async def handle_connect_button(self) -> None: + contact_panel = self.query_one("#contact-container", ContactPanel) + agent_url = contact_panel.get_url() + + if not agent_url: + logger.warning("Connect attempted with empty URL") + messages_panel = self.query_one("#messages-container", MessagesPanel) + messages_panel.add_system_message("✗ Please enter an agent URL") + return + + messages_panel = self.query_one("#messages-container", MessagesPanel) + messages_panel.add_system_message(f"Connecting to {agent_url}...") + + try: + agent_card = await self._connect_to_agent(agent_url) + + self.current_agent_card = agent_card + self.current_agent_url = agent_url + self.current_context_id = str(uuid.uuid4()) + + logger.info("Successfully connected to %s", agent_card.name) + + self._update_ui_for_connected_state(agent_card) + messages_panel.add_system_message(f"✓ Connected to {agent_card.name}") + + agent_card_panel = self.query_one("#agent-card-container", AgentCardPanel) + agent_card_panel.focus() + + except Exception as error: + logger.error("Connection failed: %s", error, exc_info=True) + messages_panel.add_system_message(f"✗ Connection failed: {error!s}") + + @on(Button.Pressed, "#disconnect-btn") + def handle_disconnect_button(self) -> None: + logger.info("Disconnecting from %s", self.current_agent_url) + self.current_agent_card = None + self.current_agent_url = None + self._agent_service = None + + messages_panel = self.query_one("#messages-container", MessagesPanel) + messages_panel.add_system_message("Disconnected") + + self._update_ui_for_disconnected_state() + + @on(Input.Submitted, "#message-input") + async def handle_message_submit(self) -> None: + if self.current_agent_url: + await self._send_message() + else: + messages_panel = self.query_one("#messages-container", MessagesPanel) + messages_panel.add_system_message("✗ Not connected to an agent") + + @on(Button.Pressed, "#send-btn") + async def handle_send_button(self) -> None: + if self.current_agent_url: + await self._send_message() + else: + messages_panel = self.query_one("#messages-container", MessagesPanel) + messages_panel.add_system_message("✗ Not connected to an agent") + + async def _send_message(self) -> None: + if not self.current_agent_url or not self._agent_service: + logger.warning("Attempted to send message without connection") + messages_panel = self.query_one("#messages-container", MessagesPanel) + messages_panel.add_system_message("✗ Not connected to an agent") + return + + input_panel = self.query_one("#input-container", InputPanel) + message_text = input_panel.get_message() + + if not message_text: + return + + messages_panel = self.query_one("#messages-container", MessagesPanel) + messages_panel.add_message("user", message_text) + + try: + logger.info("Sending message: %s", message_text[:50]) + + send_result = await self._agent_service.send( + message_text, + context_id=self.current_context_id, + ) + + response_text = send_result.text or "No text in response" + messages_panel.add_message("agent", response_text) + + except Exception as error: + logger.error("Error sending message: %s", error, exc_info=True) + messages_panel.add_system_message(f"✗ Error: {error!s}") + + async def action_clear_chat(self) -> None: + messages_panel = self.query_one("#messages-container", MessagesPanel) + await messages_panel.clear() + + async def on_unmount(self) -> None: + logger.info("Shutting down TUI application") + if self.http_client: + await self.http_client.aclose() + + +def main() -> None: + """Entry point for the TUI application.""" + application = HandlerTUI() + application.run() + + +if __name__ == "__main__": + main() diff --git a/src/a2a_handler/tui.tcss b/src/a2a_handler/tui/app.tcss similarity index 100% rename from src/a2a_handler/tui.tcss rename to src/a2a_handler/tui/app.tcss diff --git a/src/a2a_handler/components/__init__.py b/src/a2a_handler/tui/components/__init__.py similarity index 74% rename from src/a2a_handler/components/__init__.py rename to src/a2a_handler/tui/components/__init__.py index aff261c..617ac3c 100644 --- a/src/a2a_handler/components/__init__.py +++ b/src/a2a_handler/tui/components/__init__.py @@ -1,4 +1,6 @@ -from .agent_card import AgentCardPanel +"""TUI component widgets for the Handler application.""" + +from .card import AgentCardPanel from .contact import ContactPanel from .footer import Footer from .input import InputPanel diff --git a/src/a2a_handler/tui/components/card.py b/src/a2a_handler/tui/components/card.py new file mode 100644 index 0000000..eb3df37 --- /dev/null +++ b/src/a2a_handler/tui/components/card.py @@ -0,0 +1,266 @@ +"""Agent card panel component for displaying agent metadata and capabilities.""" + +import json +import re +from typing import Any + +from a2a.types import AgentCard +from rich.syntax import Syntax +from textual.app import ComposeResult +from textual.binding import Binding +from textual.containers import Container, VerticalScroll +from textual.widgets import Static, TabbedContent, TabPane, Tabs + +from a2a_handler.common import get_logger + +logger = get_logger(__name__) + +TEXTUAL_TO_SYNTAX_THEME_MAP: dict[str, str] = { + "gruvbox": "gruvbox-dark", + "nord": "nord", + "tokyo-night": "monokai", + "textual-dark": "monokai", + "textual-light": "default", + "solarized-light": "solarized-light", + "dracula": "dracula", + "catppuccin-mocha": "monokai", + "monokai": "monokai", +} + +SHORT_VIEW_FIELDS = [ + "name", + "description", + "version", + "url", + "defaultInputModes", + "defaultOutputModes", +] + + +class AgentCardPanel(Container): + """Panel displaying agent card information with tabs.""" + + BINDINGS = [ + Binding("h", "previous_tab", "Previous Tab", show=False), + Binding("l", "next_tab", "Next Tab", show=False), + Binding("left", "previous_tab", "Previous Tab", show=False), + Binding("right", "next_tab", "Next Tab", show=False), + Binding("j", "scroll_down", "Scroll Down", show=False), + Binding("k", "scroll_up", "Scroll Up", show=False), + Binding("down", "scroll_down", "Scroll Down", show=False), + Binding("up", "scroll_up", "Scroll Up", show=False), + ] + + can_focus = True + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._current_agent_card: AgentCard | None = None + + def compose(self) -> ComposeResult: + with TabbedContent(id="agent-card-tabs"): + with TabPane("Short", id="short-tab"): + yield VerticalScroll( + Static("Not connected", id="agent-short"), + id="short-scroll", + ) + with TabPane("Long", id="long-tab"): + yield VerticalScroll( + Static("", id="agent-long"), + id="long-scroll", + ) + with TabPane("Raw", id="raw-tab"): + yield VerticalScroll( + Static("", id="agent-raw"), + id="raw-scroll", + ) + + def on_mount(self) -> None: + self.border_title = "AGENT CARD" + self.border_subtitle = "READY" + for widget in self.query("TabbedContent, Tabs, Tab, TabPane, VerticalScroll"): + widget.can_focus = False + logger.debug("Agent card panel mounted") + + def _get_syntax_theme_for_current_app_theme(self) -> str: + """Get the Rich Syntax theme name for the current app theme.""" + current_theme = self.app.theme or "" + return TEXTUAL_TO_SYNTAX_THEME_MAP.get(current_theme, "monokai") + + def _convert_key_to_sentence_case(self, key: str) -> str: + """Convert a camelCase or snake_case key to sentence case.""" + spaced_key = re.sub(r"([a-z])([A-Z])", r"\1 \2", key) + return spaced_key.replace("_", " ").capitalize() + + def _is_value_empty(self, value: Any) -> bool: + """Check if a value is truly empty, including nested structures.""" + if value is None: + return True + if isinstance(value, (str, list, dict)) and not value: + return True + if isinstance(value, dict): + return all( + self._is_value_empty(nested_value) for nested_value in value.values() + ) + if isinstance(value, list): + return all(self._is_value_empty(item) for item in value) + return False + + def _format_nested_value(self, value: Any, indentation_level: int = 0) -> str: + """Format a nested value for display with proper indentation.""" + indentation_prefix = " " * indentation_level + + if isinstance(value, dict): + formatted_lines = [] + for key, nested_value in value.items(): + if self._is_value_empty(nested_value): + continue + formatted_key = self._convert_key_to_sentence_case(key) + if isinstance(nested_value, (list, dict)): + formatted_lines.append( + f"{indentation_prefix}[bold]{formatted_key}[/]" + ) + formatted_lines.append( + self._format_nested_value(nested_value, indentation_level + 1) + ) + else: + formatted_lines.append( + f"{indentation_prefix}[bold]{formatted_key}:[/] {nested_value}" + ) + return "\n".join(formatted_lines) + + if isinstance(value, list): + formatted_lines = [] + for item in value: + if self._is_value_empty(item): + continue + if isinstance(item, dict): + formatted_lines.append( + self._format_nested_value(item, indentation_level) + ) + else: + formatted_lines.append(f"{indentation_prefix}• {item}") + return "\n".join(formatted_lines) + + return f"{indentation_prefix}{value}" + + def _build_short_view_content(self, agent_card: AgentCard) -> str: + """Build the short view with essential fields only.""" + card_data = agent_card.model_dump() + formatted_lines = [] + + for field_name in SHORT_VIEW_FIELDS: + field_value = card_data.get(field_name) + if self._is_value_empty(field_value): + continue + formatted_key = self._convert_key_to_sentence_case(field_name) + if isinstance(field_value, (list, dict)): + formatted_lines.append(f"[bold]{formatted_key}[/]") + formatted_lines.append( + self._format_nested_value(field_value, indentation_level=1) + ) + else: + formatted_lines.append(f"[bold]{formatted_key}:[/] {field_value}") + + return "\n".join(formatted_lines) + + def _build_long_view_content(self, agent_card: AgentCard) -> str: + """Build the long view with all non-empty fields.""" + card_data = agent_card.model_dump() + formatted_lines = [] + + for field_name, field_value in card_data.items(): + if self._is_value_empty(field_value): + continue + formatted_key = self._convert_key_to_sentence_case(field_name) + if isinstance(field_value, (list, dict)): + formatted_lines.append(f"[bold]{formatted_key}[/]") + formatted_lines.append( + self._format_nested_value(field_value, indentation_level=1) + ) + else: + formatted_lines.append(f"[bold]{formatted_key}:[/] {field_value}") + + return "\n".join(formatted_lines) + + def update_card(self, agent_card: AgentCard | None) -> None: + """Update the displayed agent card.""" + self._current_agent_card = agent_card + + short_view_widget = self.query_one("#agent-short", Static) + long_view_widget = self.query_one("#agent-long", Static) + raw_view_widget = self.query_one("#agent-raw", Static) + + if agent_card is None: + logger.debug("Clearing agent card display") + short_view_widget.update("Not connected") + long_view_widget.update("") + raw_view_widget.update("") + self.border_subtitle = "READY" + else: + logger.info("Displaying agent card for: %s", agent_card.name) + short_view_widget.update(self._build_short_view_content(agent_card)) + long_view_widget.update(self._build_long_view_content(agent_card)) + + json_content = json.dumps(agent_card.model_dump(), indent=2, default=str) + syntax_theme = self._get_syntax_theme_for_current_app_theme() + raw_view_widget.update(Syntax(json_content, "json", theme=syntax_theme)) + self.border_subtitle = "ACTIVE" + + def refresh_theme(self) -> None: + """Refresh the raw view syntax highlighting for theme changes.""" + if self._current_agent_card is None: + return + + logger.debug("Refreshing syntax theme for agent card raw view") + json_content = json.dumps( + self._current_agent_card.model_dump(), indent=2, default=str + ) + syntax_theme = self._get_syntax_theme_for_current_app_theme() + self.query_one("#agent-raw", Static).update( + Syntax(json_content, "json", theme=syntax_theme) + ) + + def _get_currently_active_scroll_container(self) -> VerticalScroll | None: + """Get the currently visible scroll container.""" + tabbed_content = self.query_one("#agent-card-tabs", TabbedContent) + active_tab_id = tabbed_content.active + + scroll_container_map = { + "short-tab": "#short-scroll", + "long-tab": "#long-scroll", + "raw-tab": "#raw-scroll", + } + + scroll_container_id = scroll_container_map.get(active_tab_id) + if scroll_container_id: + return self.query_one(scroll_container_id, VerticalScroll) + return None + + def action_previous_tab(self) -> None: + """Switch to the previous tab.""" + try: + tabs_widget = self.query_one("#agent-card-tabs Tabs", Tabs) + tabs_widget.action_previous_tab() + except Exception: + pass + + def action_next_tab(self) -> None: + """Switch to the next tab.""" + try: + tabs_widget = self.query_one("#agent-card-tabs Tabs", Tabs) + tabs_widget.action_next_tab() + except Exception: + pass + + def action_scroll_down(self) -> None: + """Scroll down in the active tab's scroll container.""" + scroll_container = self._get_currently_active_scroll_container() + if scroll_container: + scroll_container.scroll_down() + + def action_scroll_up(self) -> None: + """Scroll up in the active tab's scroll container.""" + scroll_container = self._get_currently_active_scroll_container() + if scroll_container: + scroll_container.scroll_up() diff --git a/src/a2a_handler/components/contact.py b/src/a2a_handler/tui/components/contact.py similarity index 54% rename from src/a2a_handler/components/contact.py rename to src/a2a_handler/tui/components/contact.py index 4fad077..2f6acdc 100644 --- a/src/a2a_handler/components/contact.py +++ b/src/a2a_handler/tui/components/contact.py @@ -1,10 +1,12 @@ -import logging +"""Contact panel component for managing agent connections.""" from textual.app import ComposeResult from textual.containers import Container, Horizontal from textual.widgets import Button, Input, Label -logger = logging.getLogger(__name__) +from a2a_handler.common import get_logger + +logger = get_logger(__name__) class ContactPanel(Container): @@ -23,12 +25,19 @@ def compose(self) -> ComposeResult: def on_mount(self) -> None: self.border_title = "CONTACT" + logger.debug("Contact panel mounted") - def set_connected(self, connected: bool) -> None: + def set_connected(self, is_connected: bool) -> None: """Update button states based on connection status.""" - self.query_one("#connect-btn", Button).disabled = connected - self.query_one("#disconnect-btn", Button).disabled = not connected + connect_button = self.query_one("#connect-btn", Button) + disconnect_button = self.query_one("#disconnect-btn", Button) + + connect_button.disabled = is_connected + disconnect_button.disabled = not is_connected + + logger.debug("Connection state updated: connected=%s", is_connected) def get_url(self) -> str: - """Get the current agent URL.""" - return self.query_one("#agent-url", Input).value.strip() + """Get the current agent URL from the input field.""" + url_input = self.query_one("#agent-url", Input) + return url_input.value.strip() diff --git a/src/a2a_handler/components/footer.py b/src/a2a_handler/tui/components/footer.py similarity index 89% rename from src/a2a_handler/components/footer.py rename to src/a2a_handler/tui/components/footer.py index b428611..070387c 100644 --- a/src/a2a_handler/components/footer.py +++ b/src/a2a_handler/tui/components/footer.py @@ -1,3 +1,5 @@ +"""Footer component displaying keyboard shortcut buttons.""" + from textual.app import ComposeResult from textual.containers import Container, Horizontal from textual.widgets import Button diff --git a/src/a2a_handler/components/input.py b/src/a2a_handler/tui/components/input.py similarity index 74% rename from src/a2a_handler/components/input.py rename to src/a2a_handler/tui/components/input.py index 8a3508a..74c5f4d 100644 --- a/src/a2a_handler/components/input.py +++ b/src/a2a_handler/tui/components/input.py @@ -1,10 +1,12 @@ -import logging +"""Input panel component for composing and sending messages.""" from textual.app import ComposeResult from textual.containers import Container, Horizontal from textual.widgets import Button, Input -logger = logging.getLogger(__name__) +from a2a_handler.common import get_logger + +logger = get_logger(__name__) class InputPanel(Container): @@ -19,14 +21,15 @@ def on_mount(self) -> None: self.border_title = "INPUT" self.border_subtitle = "PRESS ENTER TO SEND" self.query_one("#send-btn", Button).can_focus = False + logger.debug("Input panel mounted") def get_message(self) -> str: """Get and clear the current message input.""" message_input = self.query_one("#message-input", Input) - message = message_input.value.strip() + message_text = message_input.value.strip() message_input.value = "" - return message + return message_text def focus_input(self) -> None: - """Focus the message input.""" + """Focus the message input field.""" self.query_one("#message-input", Input).focus() diff --git a/src/a2a_handler/components/messages.py b/src/a2a_handler/tui/components/messages.py similarity index 59% rename from src/a2a_handler/components/messages.py rename to src/a2a_handler/tui/components/messages.py index 113ba21..0822821 100644 --- a/src/a2a_handler/components/messages.py +++ b/src/a2a_handler/tui/components/messages.py @@ -1,4 +1,5 @@ -import logging +"""Messages panel component for displaying chat history.""" + from datetime import datetime from typing import Any @@ -7,7 +8,9 @@ from textual.containers import Container, Vertical, VerticalScroll from textual.widgets import Static -logger = logging.getLogger(__name__) +from a2a_handler.common import get_logger + +logger = get_logger(__name__) class Message(Vertical): @@ -26,16 +29,16 @@ def __init__( self.timestamp = timestamp or datetime.now() def compose(self) -> ComposeResult: - time_str = self.timestamp.strftime("%H:%M:%S") + formatted_time = self.timestamp.strftime("%H:%M:%S") if self.role == "system": - yield Static(f"[dim]{time_str}[/dim] [italic]{self.text}[/italic]") + yield Static(f"[dim]{formatted_time}[/dim] [italic]{self.text}[/italic]") else: role_color = "#88c0d0" if self.role == "agent" else "#bf616a" - yield Static(f"[dim]{time_str}[/dim] [{role_color}]{self.text}[/]") + yield Static(f"[dim]{formatted_time}[/dim] [{role_color}]{self.text}[/]") -class ChatScroll(VerticalScroll): +class ChatScrollContainer(VerticalScroll): """Scrollable chat area.""" can_focus = False @@ -54,37 +57,39 @@ class MessagesPanel(Container): can_focus = True def compose(self) -> ComposeResult: - yield ChatScroll(id="chat") + yield ChatScrollContainer(id="chat") def on_mount(self) -> None: self.border_title = "MESSAGES" + logger.debug("Messages panel mounted") - def _get_chat(self) -> ChatScroll: - return self.query_one("#chat", ChatScroll) + def _get_chat_container(self) -> ChatScrollContainer: + return self.query_one("#chat", ChatScrollContainer) def add_message(self, role: str, content: str) -> None: logger.debug("Adding %s message: %s", role, content[:50]) - chat = self._get_chat() - message = Message(role, content) - chat.mount(message) - chat.scroll_end(animate=False) + chat_container = self._get_chat_container() + message_widget = Message(role, content) + chat_container.mount(message_widget) + chat_container.scroll_end(animate=False) def add_system_message(self, content: str) -> None: logger.info("System message: %s", content) self.add_message("system", content) def update_message_count(self) -> None: - chat = self._get_chat() - self.border_subtitle = f"{len(chat.children)} MESSAGES" + chat_container = self._get_chat_container() + message_count = len(chat_container.children) + self.border_subtitle = f"{message_count} MESSAGES" async def clear(self) -> None: - logger.info("Clearing chat") - chat = self._get_chat() - await chat.remove_children() + logger.info("Clearing chat messages") + chat_container = self._get_chat_container() + await chat_container.remove_children() self.add_system_message("Chat cleared") def action_scroll_down(self) -> None: - self._get_chat().scroll_down() + self._get_chat_container().scroll_down() def action_scroll_up(self) -> None: - self._get_chat().scroll_up() + self._get_chat_container().scroll_up() diff --git a/src/a2a_handler/validation.py b/src/a2a_handler/validation.py index 292c48f..7cfa0bb 100644 --- a/src/a2a_handler/validation.py +++ b/src/a2a_handler/validation.py @@ -1,4 +1,7 @@ -"""A2A protocol validation utilities.""" +"""Agent card validation utilities for the A2A protocol. + +Validates agent cards from URLs or local files using the A2A SDK. +""" import json from dataclasses import dataclass, field @@ -7,12 +10,13 @@ from typing import Any import httpx +from a2a.client import A2ACardResolver from a2a.types import AgentCard from pydantic import ValidationError from a2a_handler.common import get_logger -log = get_logger(__name__) +logger = get_logger(__name__) class ValidationSource(Enum): @@ -26,13 +30,10 @@ class ValidationSource(Enum): class ValidationIssue: """Represents a single validation issue.""" - field: str + field_name: str message: str issue_type: str = "error" - def __str__(self) -> str: - return f"[{self.issue_type}] {self.field}: {self.message}" - @dataclass class ValidationResult: @@ -43,7 +44,6 @@ class ValidationResult: source_type: ValidationSource agent_card: AgentCard | None = None issues: list[ValidationIssue] = field(default_factory=list) - warnings: list[ValidationIssue] = field(default_factory=list) raw_data: dict[str, Any] | None = None @property @@ -59,186 +59,76 @@ def agent_name(self) -> str: def protocol_version(self) -> str: """Get the protocol version if available.""" if self.agent_card: - return self.agent_card.protocol_version or "1.0" + return self.agent_card.protocol_version or "Unknown" if self.raw_data: - return self.raw_data.get("protocolVersion", "1.0") + return self.raw_data.get("protocolVersion", "Unknown") return "Unknown" -def _parse_pydantic_error(error: ValidationError) -> list[ValidationIssue]: +def _parse_validation_errors(error: ValidationError) -> list[ValidationIssue]: """Parse Pydantic validation errors into ValidationIssues.""" issues = [] - for err in error.errors(): - field_path = ".".join(str(loc) for loc in err["loc"]) - message = err["msg"] - issue_type = err["type"] + for detail in error.errors(): + field_path = ".".join(str(loc) for loc in detail["loc"]) issues.append( ValidationIssue( - field=field_path or "root", - message=message, - issue_type=issue_type, + field_name=field_path or "root", + message=detail["msg"], + issue_type=detail["type"], ) ) return issues -def _check_best_practices(card: AgentCard) -> list[ValidationIssue]: - """Check for best practices and generate warnings. - - Note: In A2A v0.3.0, the following are REQUIRED fields and validated by Pydantic: - - name, description, url, version - - capabilities, defaultInputModes, defaultOutputModes, skills - - preferredTransport (defaults to JSONRPC in SDK) - - This function only warns about optional fields that improve agent discoverability. - """ - warnings = [] - - if not card.provider: - warnings.append( - ValidationIssue( - field="provider", - message="Agent card should specify a provider for better discoverability", - issue_type="warning", - ) - ) - - if not card.documentation_url: - warnings.append( - ValidationIssue( - field="documentationUrl", - message="Agent card should include documentation URL", - issue_type="warning", - ) - ) - - if not card.icon_url: - warnings.append( - ValidationIssue( - field="iconUrl", - message="Agent card should include an icon URL for UI display", - issue_type="warning", - ) - ) - - if card.skills: - for i, skill in enumerate(card.skills): - if not skill.description: - warnings.append( - ValidationIssue( - field=f"skills[{i}].description", - message=f"Skill '{skill.name}' should have a description", - issue_type="warning", - ) - ) - if not skill.examples or len(skill.examples) == 0: - warnings.append( - ValidationIssue( - field=f"skills[{i}].examples", - message=f"Skill '{skill.name}' should include example prompts", - issue_type="warning", - ) - ) - - if not card.additional_interfaces or len(card.additional_interfaces) == 0: - warnings.append( - ValidationIssue( - field="additionalInterfaces", - message="Consider declaring additional transport interfaces for flexibility", - issue_type="warning", - ) - ) - - return warnings - - -def validate_agent_card_data( - data: dict[str, Any], source: str, source_type: ValidationSource +async def validate_agent_card_from_url( + agent_url: str, + http_client: httpx.AsyncClient | None = None, ) -> ValidationResult: - """Validate agent card data against the A2A protocol schema. + """Fetch and validate an agent card from a URL using the A2A SDK. Args: - data: Raw agent card data as a dictionary - source: The source (URL or file path) of the data - source_type: Whether the source is a URL or file + agent_url: The base URL of the agent + http_client: Optional HTTP client to use Returns: ValidationResult with validation status and any issues """ - log.debug("Validating agent card data from %s", source) + logger.info("Validating agent card from URL: %s", agent_url) - try: - card = AgentCard.model_validate(data) - log.info("Agent card validation successful for %s", card.name) + should_close_client = http_client is None + if http_client is None: + http_client = httpx.AsyncClient(timeout=30) - warnings = _check_best_practices(card) + try: + resolver = A2ACardResolver(http_client, agent_url) + agent_card = await resolver.get_agent_card() + logger.info("Agent card validation successful for %s", agent_card.name) return ValidationResult( valid=True, - source=source, - source_type=source_type, - agent_card=card, - warnings=warnings, - raw_data=data, + source=agent_url, + source_type=ValidationSource.URL, + agent_card=agent_card, ) except ValidationError as e: - log.warning("Agent card validation failed: %s", e) - issues = _parse_pydantic_error(e) - + logger.warning("Agent card validation failed: %s", e) return ValidationResult( valid=False, - source=source, - source_type=source_type, - issues=issues, - raw_data=data, + source=agent_url, + source_type=ValidationSource.URL, + issues=_parse_validation_errors(e), ) - -async def validate_agent_card_from_url( - url: str, - client: httpx.AsyncClient | None = None, - card_path: str | None = None, -) -> ValidationResult: - """Fetch and validate an agent card from a URL. - - Args: - url: The base URL of the agent - client: Optional HTTP client to use - card_path: Optional custom path to the agent card (default: /.well-known/agent.json) - - Returns: - ValidationResult with validation status and any issues - """ - log.info("Validating agent card from URL: %s", url) - - should_close = client is None - if client is None: - client = httpx.AsyncClient(timeout=30) - - try: - base_url = url.rstrip("/") - if card_path: - full_url = f"{base_url}/{card_path.lstrip('/')}" - else: - full_url = f"{base_url}/.well-known/agent-card.json" - - log.debug("Fetching agent card from %s", full_url) - response = await client.get(full_url) - response.raise_for_status() - - data = response.json() - return validate_agent_card_data(data, url, ValidationSource.URL) - except httpx.HTTPStatusError as e: - log.error("HTTP error fetching agent card: %s", e) + logger.error("HTTP error fetching agent card: %s", e) return ValidationResult( valid=False, - source=url, + source=agent_url, source_type=ValidationSource.URL, issues=[ ValidationIssue( - field="http", + field_name="http", message=f"HTTP {e.response.status_code}: {e.response.text[:200]}", issue_type="http_error", ) @@ -246,38 +136,23 @@ async def validate_agent_card_from_url( ) except httpx.RequestError as e: - log.error("Request error fetching agent card: %s", e) + logger.error("Request error fetching agent card: %s", e) return ValidationResult( valid=False, - source=url, + source=agent_url, source_type=ValidationSource.URL, issues=[ ValidationIssue( - field="connection", + field_name="connection", message=str(e), issue_type="connection_error", ) ], ) - except json.JSONDecodeError as e: - log.error("JSON decode error: %s", e) - return ValidationResult( - valid=False, - source=url, - source_type=ValidationSource.URL, - issues=[ - ValidationIssue( - field="json", - message=f"Invalid JSON: {e}", - issue_type="json_error", - ) - ], - ) - finally: - if should_close: - await client.aclose() + if should_close_client: + await http_client.aclose() def validate_agent_card_from_file(file_path: str | Path) -> ValidationResult: @@ -290,17 +165,17 @@ def validate_agent_card_from_file(file_path: str | Path) -> ValidationResult: ValidationResult with validation status and any issues """ path = Path(file_path) - log.info("Validating agent card from file: %s", path) + logger.info("Validating agent card from file: %s", path) if not path.exists(): - log.error("File not found: %s", path) + logger.error("File not found: %s", path) return ValidationResult( valid=False, source=str(path), source_type=ValidationSource.FILE, issues=[ ValidationIssue( - field="file", + field_name="file", message=f"File not found: {path}", issue_type="file_error", ) @@ -308,35 +183,56 @@ def validate_agent_card_from_file(file_path: str | Path) -> ValidationResult: ) if not path.is_file(): - log.error("Path is not a file: %s", path) + logger.error("Path is not a file: %s", path) return ValidationResult( valid=False, source=str(path), source_type=ValidationSource.FILE, issues=[ ValidationIssue( - field="file", + field_name="file", message=f"Path is not a file: {path}", issue_type="file_error", ) ], ) + card_data: dict[str, Any] | None = None + try: with open(path, encoding="utf-8") as f: - data = json.load(f) + card_data = json.load(f) + + agent_card = AgentCard.model_validate(card_data) + logger.info("Agent card validation successful for %s", agent_card.name) + + return ValidationResult( + valid=True, + source=str(path), + source_type=ValidationSource.FILE, + agent_card=agent_card, + raw_data=card_data, + ) - return validate_agent_card_data(data, str(path), ValidationSource.FILE) + except ValidationError as e: + logger.warning("Agent card validation failed: %s", e) + return ValidationResult( + valid=False, + source=str(path), + source_type=ValidationSource.FILE, + issues=_parse_validation_errors(e), + raw_data=card_data, + ) except json.JSONDecodeError as e: - log.error("JSON decode error: %s", e) + logger.error("JSON decode error: %s", e) return ValidationResult( valid=False, source=str(path), source_type=ValidationSource.FILE, issues=[ ValidationIssue( - field="json", + field_name="json", message=f"Invalid JSON at line {e.lineno}, column {e.colno}: {e.msg}", issue_type="json_error", ) @@ -344,14 +240,14 @@ def validate_agent_card_from_file(file_path: str | Path) -> ValidationResult: ) except PermissionError: - log.error("Permission denied reading file: %s", path) + logger.error("Permission denied reading file: %s", path) return ValidationResult( valid=False, source=str(path), source_type=ValidationSource.FILE, issues=[ ValidationIssue( - field="file", + field_name="file", message=f"Permission denied: {path}", issue_type="file_error", ) @@ -359,14 +255,14 @@ def validate_agent_card_from_file(file_path: str | Path) -> ValidationResult: ) except OSError as e: - log.error("Error reading file: %s", e) + logger.error("Error reading file: %s", e) return ValidationResult( valid=False, source=str(path), source_type=ValidationSource.FILE, issues=[ ValidationIssue( - field="file", + field_name="file", message=str(e), issue_type="file_error", ) diff --git a/src/a2a_handler/webhook.py b/src/a2a_handler/webhook.py new file mode 100644 index 0000000..cbed26b --- /dev/null +++ b/src/a2a_handler/webhook.py @@ -0,0 +1,176 @@ +"""Webhook server for receiving A2A push notifications. + +Provides an HTTP server for receiving and displaying push notifications from A2A agents. +""" + +import json +from collections import deque +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +import uvicorn +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.routing import Route + +from a2a_handler.common import console, get_logger + +logger = get_logger(__name__) + +DEFAULT_MAX_STORED_NOTIFICATIONS = 100 + + +@dataclass +class PushNotification: + """A received push notification.""" + + timestamp: datetime + task_id: str | None + payload: dict[str, Any] + headers: dict[str, str] + + +@dataclass +class PushNotificationStore: + """In-memory store for received notifications.""" + + notifications: deque[PushNotification] = field( + default_factory=lambda: deque(maxlen=DEFAULT_MAX_STORED_NOTIFICATIONS) + ) + + def add_notification(self, notification: PushNotification) -> None: + """Add a notification to the store.""" + self.notifications.append(notification) + logger.debug( + "Stored notification for task: %s (total: %d)", + notification.task_id, + len(self.notifications), + ) + + def get_all_notifications(self) -> list[PushNotification]: + """Get all stored notifications.""" + return list(self.notifications) + + def clear_all_notifications(self) -> None: + """Clear all stored notifications.""" + notification_count = len(self.notifications) + self.notifications.clear() + logger.info("Cleared %d stored notifications", notification_count) + + +notification_store = PushNotificationStore() + + +async def handle_push_notification(request: Request) -> JSONResponse: + """Handle incoming push notifications from A2A agents.""" + try: + request_payload = await request.json() + except json.JSONDecodeError: + logger.warning("Received invalid JSON in push notification") + return JSONResponse({"error": "Invalid JSON"}, status_code=400) + + request_headers = dict(request.headers) + task_id = request_payload.get("id") or request_payload.get("task_id") + + notification = PushNotification( + timestamp=datetime.now(), + task_id=task_id, + payload=request_payload, + headers=request_headers, + ) + notification_store.add_notification(notification) + + logger.info("Received push notification for task: %s", task_id) + + console.print("\n[bold cyan]Push Notification Received[/bold cyan]") + console.print(f"[dim]Timestamp:[/dim] {notification.timestamp.isoformat()}") + if task_id: + console.print(f"[dim]Task ID:[/dim] {task_id}") + + task_status = request_payload.get("status", {}) + if task_status: + task_state = task_status.get("state", "unknown") + console.print(f"[dim]State:[/dim] {task_state}") + + authentication_token = request_headers.get("x-a2a-notification-token") + if authentication_token: + console.print(f"[dim]Token:[/dim] {authentication_token[:20]}...") + + console.print() + console.print_json(json.dumps(request_payload, indent=2, default=str)) + console.print() + + return JSONResponse({"status": "ok", "received": True}) + + +async def handle_webhook_validation(request: Request) -> JSONResponse: + """Handle GET requests for webhook validation.""" + logger.info("Webhook validation request received") + return JSONResponse({"status": "ok", "message": "Webhook is active"}) + + +async def handle_list_notifications(request: Request) -> JSONResponse: + """List all received notifications.""" + all_notifications = notification_store.get_all_notifications() + logger.debug("Returning %d stored notifications", len(all_notifications)) + return JSONResponse( + { + "count": len(all_notifications), + "notifications": [ + { + "timestamp": notification.timestamp.isoformat(), + "task_id": notification.task_id, + "payload": notification.payload, + } + for notification in all_notifications + ], + } + ) + + +async def handle_clear_notifications(request: Request) -> JSONResponse: + """Clear all stored notifications.""" + notification_store.clear_all_notifications() + return JSONResponse({"status": "ok", "message": "Notifications cleared"}) + + +def create_webhook_application() -> Starlette: + """Create the webhook Starlette application.""" + application_routes = [ + Route("/webhook", handle_push_notification, methods=["POST"]), + Route("/webhook", handle_webhook_validation, methods=["GET"]), + Route("/notifications", handle_list_notifications, methods=["GET"]), + Route("/notifications/clear", handle_clear_notifications, methods=["POST"]), + ] + return Starlette(routes=application_routes) + + +def run_webhook_server( + host: str = "127.0.0.1", + port: int = 9000, +) -> None: + """Start the webhook server. + + Args: + host: Host address to bind to + port: Port number to bind to + """ + console.print(f"\n[bold]Starting webhook server on [url]{host}:{port}[/url][/bold]") + console.print() + console.print("[dim]Endpoints:[/dim]") + console.print(f" POST http://{host}:{port}/webhook - Receive notifications") + console.print(f" GET http://{host}:{port}/webhook - Validation check") + console.print(f" GET http://{host}:{port}/notifications - List received") + console.print(f" POST http://{host}:{port}/notifications/clear - Clear stored") + console.print() + console.print( + f"[bold green]Use this URL for push notifications:[/bold green] " + f"http://{host}:{port}/webhook" + ) + console.print() + + logger.info("Starting webhook server on %s:%d", host, port) + webhook_application = create_webhook_application() + uvicorn.run(webhook_application, host=host, port=port, log_level="warning") diff --git a/tests/__init__.py b/tests/__init__.py index 8b13789..abd12bc 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1 @@ - +"""Test suite for the Handler package.""" diff --git a/tests/test_formatting.py b/tests/test_formatting.py new file mode 100644 index 0000000..9b14ecf --- /dev/null +++ b/tests/test_formatting.py @@ -0,0 +1,116 @@ +"""Tests for the formatting utilities module.""" + +from a2a_handler.common.formatting import format_field_name, format_value + + +class TestFormatFieldName: + """Tests for format_field_name function.""" + + def test_snake_case(self): + """Test conversion of snake_case to Title Case.""" + assert format_field_name("snake_case") == "Snake Case" + assert format_field_name("some_field_name") == "Some Field Name" + + def test_camel_case(self): + """Test conversion of camelCase to Title Case.""" + assert format_field_name("camelCase") == "Camel Case" + assert format_field_name("someFieldName") == "Some Field Name" + + def test_already_spaced(self): + """Test handling of already spaced strings.""" + assert format_field_name("already spaced") == "Already Spaced" + + def test_single_word(self): + """Test single word remains capitalized.""" + assert format_field_name("name") == "Name" + assert format_field_name("id") == "Id" + + def test_mixed_case_with_underscores(self): + """Test mixed camelCase with underscores.""" + assert format_field_name("some_camelCase_field") == "Some Camel Case Field" + + +class TestFormatValue: + """Tests for format_value function.""" + + def test_none_returns_empty(self): + """Test None returns empty string.""" + assert format_value(None) == "" + + def test_empty_string_returns_empty(self): + """Test empty string returns empty string.""" + assert format_value("") == "" + + def test_empty_list_returns_empty(self): + """Test empty list returns empty string.""" + assert format_value([]) == "" + + def test_empty_dict_returns_empty(self): + """Test empty dict returns empty string.""" + assert format_value({}) == "" + + def test_bool_true(self): + """Test True returns checkmark.""" + assert format_value(True) == "✓" + + def test_bool_false(self): + """Test False returns X.""" + assert format_value(False) == "✗" + + def test_string(self): + """Test string returns itself.""" + assert format_value("hello") == "hello" + assert format_value("hello world") == "hello world" + + def test_integer(self): + """Test integer returns string representation.""" + assert format_value(42) == "42" + assert format_value(0) == "0" + + def test_float(self): + """Test float returns string representation.""" + assert format_value(3.14) == "3.14" + + def test_simple_dict(self): + """Test simple dict formatting.""" + result = format_value({"name": "test"}) + assert "Name:" in result + assert "test" in result + + def test_dict_skips_private_keys(self): + """Test dict skips keys starting with underscore.""" + result = format_value({"name": "test", "_private": "hidden"}) + assert "Name:" in result + assert "_private" not in result + assert "hidden" not in result + + def test_list_of_strings(self): + """Test list of strings formatting.""" + result = format_value(["a", "b", "c"]) + assert "• a" in result + assert "• b" in result + assert "• c" in result + + def test_list_of_dicts_with_name(self): + """Test list of dicts uses name field.""" + result = format_value([{"name": "Item1"}, {"name": "Item2"}]) + assert "Item1" in result + assert "Item2" in result + + def test_list_of_dicts_with_description(self): + """Test list of dicts includes description.""" + result = format_value([{"name": "Item1", "description": "Description1"}]) + assert "Item1" in result + assert "Description1" in result + + def test_nested_dict(self): + """Test nested dict formatting.""" + result = format_value({"outer": {"inner": "value"}}) + assert "Outer:" in result + assert "Inner:" in result + assert "value" in result + + def test_indentation(self): + """Test indentation is applied.""" + result = format_value({"name": "test"}, indent=1) + assert result.startswith(" ") diff --git a/tests/test_service.py b/tests/test_service.py new file mode 100644 index 0000000..c0a0f18 --- /dev/null +++ b/tests/test_service.py @@ -0,0 +1,165 @@ +"""Tests for the A2A service layer module.""" + +from a2a.types import Part, Task, TaskState, TaskStatus, TextPart + +from a2a_handler.service import ( + SendResult, + StreamEvent, + TaskResult, + TERMINAL_TASK_STATES, + extract_text_from_message_parts, +) + + +class TestSendResult: + """Tests for SendResult dataclass.""" + + def test_is_complete_when_completed(self): + """Test is_complete returns True for completed state.""" + result = SendResult(state=TaskState.completed) + assert result.is_complete is True + + def test_is_complete_when_canceled(self): + """Test is_complete returns True for canceled state.""" + result = SendResult(state=TaskState.canceled) + assert result.is_complete is True + + def test_is_complete_when_failed(self): + """Test is_complete returns True for failed state.""" + result = SendResult(state=TaskState.failed) + assert result.is_complete is True + + def test_is_complete_when_rejected(self): + """Test is_complete returns True for rejected state.""" + result = SendResult(state=TaskState.rejected) + assert result.is_complete is True + + def test_is_complete_when_working(self): + """Test is_complete returns False for working state.""" + result = SendResult(state=TaskState.working) + assert result.is_complete is False + + def test_is_complete_when_no_state(self): + """Test is_complete returns False when state is None.""" + result = SendResult() + assert result.is_complete is False + + def test_needs_input_when_input_required(self): + """Test needs_input returns True for input_required state.""" + result = SendResult(state=TaskState.input_required) + assert result.needs_input is True + + def test_needs_input_when_working(self): + """Test needs_input returns False for working state.""" + result = SendResult(state=TaskState.working) + assert result.needs_input is False + + def test_needs_input_when_no_state(self): + """Test needs_input returns False when state is None.""" + result = SendResult() + assert result.needs_input is False + + +class TestStreamEvent: + """Tests for StreamEvent dataclass.""" + + def test_create_message_event(self): + """Test creating a message event.""" + event = StreamEvent( + event_type="message", + context_id="ctx-123", + text="Hello, world!", + ) + + assert event.event_type == "message" + assert event.context_id == "ctx-123" + assert event.text == "Hello, world!" + + def test_create_status_event(self): + """Test creating a status event.""" + event = StreamEvent( + event_type="status", + task_id="task-456", + state=TaskState.working, + ) + + assert event.event_type == "status" + assert event.task_id == "task-456" + assert event.state == TaskState.working + + +class TestTaskResult: + """Tests for TaskResult dataclass.""" + + def test_create_task_result(self): + """Test creating a task result.""" + mock_task = Task( + id="task-123", + context_id="ctx-123", + status=TaskStatus(state=TaskState.completed), + ) + + result = TaskResult( + task=mock_task, + task_id="task-123", + state=TaskState.completed, + context_id="ctx-123", + text="Task completed successfully", + ) + + assert result.task_id == "task-123" + assert result.state == TaskState.completed + assert result.text == "Task completed successfully" + + +class TestExtractTextFromMessageParts: + """Tests for extract_text_from_message_parts function.""" + + def test_extract_from_none(self): + """Test extracting from None returns empty string.""" + result = extract_text_from_message_parts(None) + assert result == "" + + def test_extract_from_empty_list(self): + """Test extracting from empty list returns empty string.""" + result = extract_text_from_message_parts([]) + assert result == "" + + def test_extract_from_text_part_with_root(self): + """Test extracting from TextPart wrapped in Part.""" + parts = [Part(root=TextPart(text="Hello, world!"))] + result = extract_text_from_message_parts(parts) + assert result == "Hello, world!" + + def test_extract_multiple_parts(self): + """Test extracting from multiple parts joins with newlines.""" + parts = [ + Part(root=TextPart(text="First line")), + Part(root=TextPart(text="Second line")), + ] + result = extract_text_from_message_parts(parts) + assert result == "First line\nSecond line" + + +class TestTerminalStates: + """Tests for terminal state constants.""" + + def test_terminal_states_include_completed(self): + """Test that completed is a terminal state.""" + assert TaskState.completed in TERMINAL_TASK_STATES + + def test_terminal_states_include_canceled(self): + """Test that canceled is a terminal state.""" + assert TaskState.canceled in TERMINAL_TASK_STATES + + def test_terminal_states_include_failed(self): + """Test that failed is a terminal state.""" + assert TaskState.failed in TERMINAL_TASK_STATES + + def test_terminal_states_include_rejected(self): + """Test that rejected is a terminal state.""" + assert TaskState.rejected in TERMINAL_TASK_STATES + + def test_working_is_not_terminal(self): + """Test that working is not a terminal state.""" + assert TaskState.working not in TERMINAL_TASK_STATES diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 0000000..7319f66 --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,185 @@ +"""Tests for the session state management module.""" + +import tempfile +from pathlib import Path + +from a2a_handler.session import AgentSession, SessionStore + + +class TestAgentSession: + """Tests for AgentSession dataclass.""" + + def test_create_session_with_url_only(self): + """Test creating a session with just the URL.""" + session = AgentSession(agent_url="http://localhost:8000") + + assert session.agent_url == "http://localhost:8000" + assert session.context_id is None + assert session.task_id is None + + def test_create_session_with_all_fields(self): + """Test creating a session with all fields.""" + session = AgentSession( + agent_url="http://localhost:8000", + context_id="ctx-123", + task_id="task-456", + ) + + assert session.agent_url == "http://localhost:8000" + assert session.context_id == "ctx-123" + assert session.task_id == "task-456" + + def test_update_context_id(self): + """Test updating context_id.""" + session = AgentSession(agent_url="http://localhost:8000") + session.update(context_id="new-context") + + assert session.context_id == "new-context" + assert session.task_id is None + + def test_update_task_id(self): + """Test updating task_id.""" + session = AgentSession(agent_url="http://localhost:8000") + session.update(task_id="new-task") + + assert session.context_id is None + assert session.task_id == "new-task" + + def test_update_both_ids(self): + """Test updating both context_id and task_id.""" + session = AgentSession(agent_url="http://localhost:8000") + session.update(context_id="ctx-1", task_id="task-1") + + assert session.context_id == "ctx-1" + assert session.task_id == "task-1" + + def test_update_preserves_existing_values_when_none_passed(self): + """Test that update preserves existing values when None is passed.""" + session = AgentSession( + agent_url="http://localhost:8000", + context_id="existing-ctx", + task_id="existing-task", + ) + session.update() + + assert session.context_id == "existing-ctx" + assert session.task_id == "existing-task" + + +class TestSessionStore: + """Tests for SessionStore.""" + + def test_get_creates_new_session(self): + """Test that get creates a new session if none exists.""" + store = SessionStore() + session = store.get("http://localhost:8000") + + assert session.agent_url == "http://localhost:8000" + assert "http://localhost:8000" in store.sessions + + def test_get_returns_existing_session(self): + """Test that get returns existing session.""" + store = SessionStore() + store.sessions["http://localhost:8000"] = AgentSession( + agent_url="http://localhost:8000", + context_id="existing-ctx", + ) + + session = store.get("http://localhost:8000") + assert session.context_id == "existing-ctx" + + def test_update_creates_and_updates_session(self): + """Test that update creates and updates session.""" + with tempfile.TemporaryDirectory() as temp_directory: + store = SessionStore(session_directory=Path(temp_directory)) + session = store.update( + "http://localhost:8000", + context_id="new-ctx", + task_id="new-task", + ) + + assert session.context_id == "new-ctx" + assert session.task_id == "new-task" + + def test_clear_specific_session(self): + """Test clearing a specific session.""" + store = SessionStore() + store.sessions["http://localhost:8000"] = AgentSession( + agent_url="http://localhost:8000" + ) + store.sessions["http://localhost:9000"] = AgentSession( + agent_url="http://localhost:9000" + ) + + with tempfile.TemporaryDirectory() as temp_directory: + store.session_directory = Path(temp_directory) + store.clear("http://localhost:8000") + + assert "http://localhost:8000" not in store.sessions + assert "http://localhost:9000" in store.sessions + + def test_clear_all_sessions(self): + """Test clearing all sessions.""" + with tempfile.TemporaryDirectory() as temp_directory: + store = SessionStore(session_directory=Path(temp_directory)) + store.sessions["http://localhost:8000"] = AgentSession( + agent_url="http://localhost:8000" + ) + store.sessions["http://localhost:9000"] = AgentSession( + agent_url="http://localhost:9000" + ) + + store.clear() + + assert len(store.sessions) == 0 + + def test_list_all_sessions(self): + """Test listing all sessions.""" + store = SessionStore() + store.sessions["http://localhost:8000"] = AgentSession( + agent_url="http://localhost:8000" + ) + store.sessions["http://localhost:9000"] = AgentSession( + agent_url="http://localhost:9000" + ) + + all_sessions = store.list_all() + assert len(all_sessions) == 2 + + def test_save_and_load_sessions(self): + """Test saving and loading sessions from disk.""" + with tempfile.TemporaryDirectory() as temp_directory: + store = SessionStore(session_directory=Path(temp_directory)) + store.sessions["http://localhost:8000"] = AgentSession( + agent_url="http://localhost:8000", + context_id="ctx-123", + task_id="task-456", + ) + store.save() + + new_store = SessionStore(session_directory=Path(temp_directory)) + new_store.load() + + assert "http://localhost:8000" in new_store.sessions + loaded_session = new_store.sessions["http://localhost:8000"] + assert loaded_session.context_id == "ctx-123" + assert loaded_session.task_id == "task-456" + + def test_load_nonexistent_file(self): + """Test loading from nonexistent file does nothing.""" + with tempfile.TemporaryDirectory() as temp_directory: + store = SessionStore(session_directory=Path(temp_directory)) + store.load() + + assert len(store.sessions) == 0 + + def test_load_invalid_json(self): + """Test loading invalid JSON file handles gracefully.""" + with tempfile.TemporaryDirectory() as temp_directory: + session_file = Path(temp_directory) / "sessions.json" + session_file.write_text("not valid json {{{") + + store = SessionStore(session_directory=Path(temp_directory)) + store.load() + + assert len(store.sessions) == 0 diff --git a/tests/test_tui.py b/tests/test_tui.py index 4afae9c..4177d6f 100644 --- a/tests/test_tui.py +++ b/tests/test_tui.py @@ -1,4 +1,7 @@ +"""Tests for the TUI application.""" + import pytest + from a2a_handler.tui import HandlerTUI diff --git a/tests/test_validation.py b/tests/test_validation.py index 49736f9..6c06284 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,12 +1,13 @@ -"""Tests for A2A protocol validation.""" +"""Tests for the agent card validation module.""" import json import tempfile from pathlib import Path +from a2a.types import AgentCard + from a2a_handler.validation import ( ValidationSource, - validate_agent_card_data, validate_agent_card_from_file, ) @@ -36,58 +37,38 @@ def _minimal_valid_agent_card() -> dict: } -class TestValidateAgentCardData: - """Tests for validate_agent_card_data function.""" +class TestAgentCardValidation: + """Tests for agent card validation using the A2A SDK.""" def test_valid_minimal_card(self): """Test validation of a minimal valid agent card.""" data = _minimal_valid_agent_card() - result = validate_agent_card_data(data, "test", ValidationSource.FILE) + card = AgentCard.model_validate(data) - assert result.valid is True - assert result.agent_card is not None - assert result.agent_card.name == "Test Agent" - assert len(result.issues) == 0 + assert card.name == "Test Agent" + assert card.description == "A test agent" + assert len(card.skills) == 1 def test_missing_required_field(self): """Test validation fails when required field is missing.""" data = {"url": "http://localhost:8000"} - result = validate_agent_card_data(data, "test", ValidationSource.FILE) - - assert result.valid is False - assert len(result.issues) > 0 - field_names = [i.field for i in result.issues] - assert "name" in field_names - - def test_warnings_for_optional_fields(self): - """Test warnings are generated for missing optional fields.""" - data = _minimal_valid_agent_card() - result = validate_agent_card_data(data, "test", ValidationSource.FILE) - assert result.valid is True - warning_fields = [w.field for w in result.warnings] - assert "provider" in warning_fields - assert "documentationUrl" in warning_fields - assert "iconUrl" in warning_fields + try: + AgentCard.model_validate(data) + assert False, "Expected validation to fail" + except Exception: + pass def test_skill_without_tags_fails_validation(self): """Test that skills without tags fail validation (tags are required in v0.3.0).""" data = _minimal_valid_agent_card() data["skills"] = [{"id": "test", "name": "Test", "description": "Test desc"}] - result = validate_agent_card_data(data, "test", ValidationSource.FILE) - assert result.valid is False - issue_fields = [i.field for i in result.issues] - assert any("skills" in f and "tags" in f for f in issue_fields) - - def test_skill_without_examples_generates_warning(self): - """Test that skills without examples generate warnings.""" - data = _minimal_valid_agent_card() - result = validate_agent_card_data(data, "test", ValidationSource.FILE) - - assert result.valid is True - warning_fields = [w.field for w in result.warnings] - assert any("examples" in f for f in warning_fields) + try: + AgentCard.model_validate(data) + assert False, "Expected validation to fail" + except Exception: + pass class TestValidateAgentCardFromFile: @@ -147,30 +128,54 @@ class TestValidationResult: def test_agent_name_from_card(self): """Test agent_name property returns name from agent card.""" data = _minimal_valid_agent_card() - result = validate_agent_card_data(data, "test", ValidationSource.FILE) - assert result.agent_name == "Test Agent" + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(data, f) + f.flush() + + result = validate_agent_card_from_file(f.name) + assert result.agent_name == "Test Agent" + + Path(f.name).unlink() def test_agent_name_from_raw_data(self): """Test agent_name property returns name from raw data when card is None.""" data = {"name": "Raw Agent", "url": "invalid"} - result = validate_agent_card_data(data, "test", ValidationSource.FILE) - assert result.valid is False - assert result.agent_name == "Raw Agent" + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(data, f) + f.flush() + + result = validate_agent_card_from_file(f.name) + assert result.valid is False + assert result.agent_name == "Raw Agent" + + Path(f.name).unlink() def test_protocol_version_from_sdk(self): """Test protocol_version returns the SDK default version.""" data = _minimal_valid_agent_card() - result = validate_agent_card_data(data, "test", ValidationSource.FILE) - assert result.protocol_version is not None - assert len(result.protocol_version) > 0 + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(data, f) + f.flush() + + result = validate_agent_card_from_file(f.name) + assert result.protocol_version is not None + assert len(result.protocol_version) > 0 + + Path(f.name).unlink() def test_protocol_version_explicit(self): """Test protocol_version returns explicit version when set.""" data = _minimal_valid_agent_card() data["protocolVersion"] = "2.0" - result = validate_agent_card_data(data, "test", ValidationSource.FILE) - assert result.protocol_version == "2.0" + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(data, f) + f.flush() + + result = validate_agent_card_from_file(f.name) + assert result.protocol_version == "2.0" + + Path(f.name).unlink() diff --git a/tests/test_webhook.py b/tests/test_webhook.py new file mode 100644 index 0000000..d346a80 --- /dev/null +++ b/tests/test_webhook.py @@ -0,0 +1,158 @@ +"""Tests for the webhook server module.""" + +from datetime import datetime + +import pytest +from starlette.testclient import TestClient + +from a2a_handler.webhook import ( + PushNotification, + PushNotificationStore, + create_webhook_application, +) + + +class TestPushNotification: + """Tests for PushNotification dataclass.""" + + def test_create_notification(self): + """Test creating a notification.""" + timestamp = datetime.now() + notification = PushNotification( + timestamp=timestamp, + task_id="task-123", + payload={"status": {"state": "completed"}}, + headers={"content-type": "application/json"}, + ) + + assert notification.timestamp == timestamp + assert notification.task_id == "task-123" + assert notification.payload["status"]["state"] == "completed" + assert notification.headers["content-type"] == "application/json" + + def test_create_notification_without_task_id(self): + """Test creating a notification without task_id.""" + notification = PushNotification( + timestamp=datetime.now(), + task_id=None, + payload={}, + headers={}, + ) + + assert notification.task_id is None + + +class TestPushNotificationStore: + """Tests for PushNotificationStore.""" + + def test_add_notification(self): + """Test adding a notification to the store.""" + store = PushNotificationStore() + notification = PushNotification( + timestamp=datetime.now(), + task_id="task-123", + payload={}, + headers={}, + ) + + store.add_notification(notification) + assert len(store.notifications) == 1 + + def test_get_all_notifications(self): + """Test getting all notifications.""" + store = PushNotificationStore() + notification1 = PushNotification( + timestamp=datetime.now(), + task_id="task-1", + payload={}, + headers={}, + ) + notification2 = PushNotification( + timestamp=datetime.now(), + task_id="task-2", + payload={}, + headers={}, + ) + + store.add_notification(notification1) + store.add_notification(notification2) + + all_notifications = store.get_all_notifications() + assert len(all_notifications) == 2 + + def test_clear_all_notifications(self): + """Test clearing all notifications.""" + store = PushNotificationStore() + notification = PushNotification( + timestamp=datetime.now(), + task_id="task-123", + payload={}, + headers={}, + ) + + store.add_notification(notification) + store.clear_all_notifications() + + assert len(store.notifications) == 0 + + +class TestWebhookApplication: + """Tests for the webhook Starlette application.""" + + @pytest.fixture + def client(self): + """Create a test client for the webhook application.""" + application = create_webhook_application() + return TestClient(application) + + def test_webhook_validation_get(self, client): + """Test GET request for webhook validation.""" + response = client.get("/webhook") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert "message" in data + + def test_webhook_receive_notification(self, client): + """Test POST request to receive notification.""" + payload = { + "id": "task-123", + "status": {"state": "completed"}, + } + + response = client.post("/webhook", json=payload) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["received"] is True + + def test_webhook_invalid_json(self, client): + """Test POST request with invalid JSON.""" + response = client.post( + "/webhook", + content="not valid json", + headers={"content-type": "application/json"}, + ) + + assert response.status_code == 400 + data = response.json() + assert "error" in data + + def test_list_notifications(self, client): + """Test listing received notifications.""" + response = client.get("/notifications") + + assert response.status_code == 200 + data = response.json() + assert "count" in data + assert "notifications" in data + + def test_clear_notifications(self, client): + """Test clearing notifications.""" + response = client.post("/notifications/clear") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok"