From d45e875712e619ab8b3b9d9b1d48cc8d3e460486 Mon Sep 17 00:00:00 2001 From: Dobes Vandermeer Date: Wed, 18 Mar 2026 21:36:17 -0700 Subject: [PATCH] feat(auth): add mTLS and custom header support for A2A connections Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- src/a2a_handler/auth.py | 95 ++++++++- src/a2a_handler/cli/_helpers.py | 11 +- src/a2a_handler/cli/auth.py | 113 +++++++++-- src/a2a_handler/cli/card.py | 15 +- src/a2a_handler/cli/message.py | 53 ++++- src/a2a_handler/cli/task.py | 10 +- src/a2a_handler/mcp/server.py | 202 +++++++++++++++---- src/a2a_handler/service.py | 19 +- src/a2a_handler/tui/app.py | 35 +++- src/a2a_handler/tui/app.tcss | 21 +- src/a2a_handler/tui/components/auth.py | 75 ++++++- src/a2a_handler/tui/components/card.py | 75 ++++++- src/a2a_handler/tui/components/contact.py | 5 + tests/test_auth.py | 172 ++++++++++++++++ tests/test_cli_auth.py | 228 +++++++++++++++++++++- tests/test_cli_card.py | 6 +- tests/test_session.py | 42 ++++ 17 files changed, 1057 insertions(+), 120 deletions(-) diff --git a/src/a2a_handler/auth.py b/src/a2a_handler/auth.py index f688306..c82ca01 100644 --- a/src/a2a_handler/auth.py +++ b/src/a2a_handler/auth.py @@ -1,11 +1,13 @@ """Authentication support for A2A protocol. Handles credential storage and HTTP authentication header generation. -Currently supports API key and HTTP bearer authentication schemes. +Supports API key, HTTP bearer, and mTLS (mutual TLS) authentication schemes. """ +import ssl from dataclasses import dataclass from enum import Enum +from pathlib import Path class AuthType(str, Enum): @@ -13,6 +15,7 @@ class AuthType(str, Enum): API_KEY = "api_key" BEARER = "bearer" + MTLS = "mtls" @dataclass @@ -23,8 +26,12 @@ class AuthCredentials: """ auth_type: AuthType - value: str + value: str = "" header_name: str | None = None # For API key: custom header name + cert_path: str | None = None # For mTLS: client certificate path + key_path: str | None = None # For mTLS: client private key path + ca_cert_path: str | None = None # For mTLS: CA certificate path + custom_headers: dict[str, str] | None = None # Additional headers for any auth type def to_headers(self) -> dict[str, str]: """Generate HTTP headers for this credential. @@ -32,31 +39,78 @@ def to_headers(self) -> dict[str, str]: Returns: Dictionary of headers to include in requests """ - if self.auth_type == AuthType.BEARER: - return {"Authorization": f"Bearer {self.value}"} + headers: dict[str, str] = {} + if self.auth_type == AuthType.BEARER and self.value: + headers["Authorization"] = f"Bearer {self.value}" elif self.auth_type == AuthType.API_KEY: header = self.header_name or "X-API-Key" - return {header: self.value} - return {} - - def to_dict(self) -> dict[str, str | None]: + headers[header] = self.value + if self.custom_headers: + headers.update(self.custom_headers) + return headers + + def build_ssl_context(self) -> ssl.SSLContext: + """Build an SSL context for mTLS client certificate authentication.""" + if self.auth_type != AuthType.MTLS: + raise ValueError("SSL context can only be built for mTLS credentials") + if not self.cert_path or not self.key_path: + raise ValueError("cert_path and key_path are required for mTLS") + + if self.ca_cert_path: + ctx = ssl.create_default_context(cafile=self.ca_cert_path) + else: + ctx = ssl.create_default_context() + + ctx.load_cert_chain(certfile=self.cert_path, keyfile=self.key_path) + return ctx + + def to_dict(self) -> dict: """Serialize credentials for storage.""" - return { + data: dict = { "auth_type": self.auth_type.value, "value": self.value, "header_name": self.header_name, } + if self.cert_path: + data["cert_path"] = self.cert_path + if self.key_path: + data["key_path"] = self.key_path + if self.ca_cert_path: + data["ca_cert_path"] = self.ca_cert_path + if self.custom_headers: + data["custom_headers"] = self.custom_headers + return data @classmethod - def from_dict(cls, data: dict[str, str | None]) -> "AuthCredentials": + def from_dict(cls, data: dict) -> "AuthCredentials": """Deserialize credentials from storage.""" + custom_headers_raw = data.get("custom_headers") + custom_headers = ( + dict(custom_headers_raw) if isinstance(custom_headers_raw, dict) else None + ) return cls( auth_type=AuthType(data["auth_type"]), value=data.get("value") or "", header_name=data.get("header_name"), + cert_path=data.get("cert_path"), + key_path=data.get("key_path"), + ca_cert_path=data.get("ca_cert_path"), + custom_headers=custom_headers, ) +def parse_header_string(header: str) -> tuple[str, str]: + """Parse a 'Name: Value' header string into a (name, value) tuple.""" + if ":" not in header: + raise ValueError(f"Invalid header format (expected 'Name: Value'): {header}") + name, _, value = header.partition(":") + name = name.strip() + value = value.strip() + if not name: + raise ValueError(f"Empty header name in: {header}") + return name, value + + def create_bearer_auth(token: str) -> AuthCredentials: """Create bearer token authentication.""" return AuthCredentials(auth_type=AuthType.BEARER, value=token) @@ -77,3 +131,24 @@ def create_api_key_auth( value=key, header_name=header_name, ) + + +def create_mtls_auth( + cert_path: str, + key_path: str, + ca_cert_path: str | None = None, +) -> AuthCredentials: + """Create mTLS (mutual TLS) client certificate authentication.""" + if not Path(cert_path).is_file(): + raise FileNotFoundError(f"Client certificate not found: {cert_path}") + if not Path(key_path).is_file(): + raise FileNotFoundError(f"Client private key not found: {key_path}") + if ca_cert_path and not Path(ca_cert_path).is_file(): + raise FileNotFoundError(f"CA certificate not found: {ca_cert_path}") + + return AuthCredentials( + auth_type=AuthType.MTLS, + cert_path=cert_path, + key_path=key_path, + ca_cert_path=ca_cert_path, + ) diff --git a/src/a2a_handler/cli/_helpers.py b/src/a2a_handler/cli/_helpers.py index d068d97..826730d 100644 --- a/src/a2a_handler/cli/_helpers.py +++ b/src/a2a_handler/cli/_helpers.py @@ -8,6 +8,7 @@ A2AClientTimeoutError, ) +from a2a_handler.auth import AuthCredentials, AuthType from a2a_handler.common import Output, get_logger from a2a_handler.common.input_validation import InputValidationError @@ -15,8 +16,16 @@ log = get_logger(__name__) -def build_http_client(timeout: int = TIMEOUT) -> httpx.AsyncClient: +def build_http_client( + timeout: int = TIMEOUT, + credentials: AuthCredentials | None = None, +) -> httpx.AsyncClient: """Build an HTTP client with the specified timeout.""" + if credentials and credentials.auth_type == AuthType.MTLS: + return httpx.AsyncClient( + timeout=timeout, + verify=credentials.build_ssl_context(), + ) return httpx.AsyncClient(timeout=timeout) diff --git a/src/a2a_handler/cli/auth.py b/src/a2a_handler/cli/auth.py index bbe67fa..d3f411f 100644 --- a/src/a2a_handler/cli/auth.py +++ b/src/a2a_handler/cli/auth.py @@ -4,7 +4,13 @@ import rich_click as click -from a2a_handler.auth import AuthType, create_api_key_auth, create_bearer_auth +from a2a_handler.auth import ( + AuthType, + create_api_key_auth, + create_bearer_auth, + create_mtls_auth, + parse_header_string, +) from a2a_handler.common import Output from a2a_handler.common.input_validation import ( InputValidationError, @@ -31,15 +37,34 @@ def auth() -> None: default="X-API-Key", help="Header name for API key (default: X-API-Key)", ) +@click.option("--cert", "cert_path", help="Client certificate path for mTLS (PEM)") +@click.option("--key", "key_path", help="Client private key path for mTLS (PEM)") +@click.option( + "--ca-cert", + "ca_cert_path", + help="CA certificate path for mTLS server verification (PEM)", +) +@click.option( + "--header", + "-H", + "headers", + multiple=True, + help="Custom header (repeatable, format: 'Name: Value')", +) def auth_set( agent_url: str, bearer_token: Optional[str], api_key: Optional[str], api_key_header: str, + cert_path: Optional[str], + key_path: Optional[str], + ca_cert_path: Optional[str], + headers: tuple[str, ...], ) -> None: """Set authentication credentials for an agent. - Provide either --bearer or --api-key (not both). + Provide --bearer, --api-key, or --cert/--key for mTLS. Custom headers + can be added to any auth method with --header/-H. """ output = Output() try: @@ -53,24 +78,63 @@ def auth_set( handle_validation_error(error, output) raise click.Abort() from error - if bearer_token and api_key: - output.error("Provide either --bearer or --api-key, not both") + custom_headers: dict[str, str] | None = None + if headers: + custom_headers = {} + for h in headers: + try: + name, value = parse_header_string(h) + reject_control_chars(name, "header name") + reject_control_chars(value, "header value") + custom_headers[name] = value + except (ValueError, InputValidationError) as e: + output.error(str(e)) + raise click.Abort() from e + + has_mtls = cert_path or key_path + method_count = sum(bool(x) for x in [bearer_token, api_key, has_mtls]) + + if method_count > 1: + output.error( + "Provide only one auth method: --bearer, --api-key, or --cert/--key" + ) raise click.Abort() - if not bearer_token and not api_key: - output.error("Provide --bearer or --api-key") + if method_count == 0 and not custom_headers: + output.error("Provide --bearer, --api-key, --cert/--key, or --header") raise click.Abort() - if bearer_token: + if has_mtls: + if not cert_path or not key_path: + output.error("mTLS requires both --cert and --key") + raise click.Abort() + try: + credentials = create_mtls_auth(cert_path, key_path, ca_cert_path) + except FileNotFoundError as e: + output.error(str(e)) + raise click.Abort() from e + auth_type_display = "mTLS client certificate" + elif bearer_token: credentials = create_bearer_auth(bearer_token) auth_type_display = "Bearer token" - else: - credentials = create_api_key_auth(api_key or "", header_name=api_key_header) + elif api_key: + credentials = create_api_key_auth(api_key, header_name=api_key_header) auth_type_display = f"API key (header: {api_key_header})" + else: + from a2a_handler.auth import AuthCredentials + + credentials = AuthCredentials(auth_type=AuthType.BEARER) + auth_type_display = "Custom headers only" + + credentials.custom_headers = custom_headers set_credentials(agent_url, credentials) - output.success(f"Set {auth_type_display} for {agent_url}") + parts = [auth_type_display] + if custom_headers: + header_names = ", ".join(custom_headers.keys()) + parts.append(f"+ headers: {header_names}") + output.success(f"Set {' '.join(parts)} for {agent_url}") @auth.command("show") @@ -93,15 +157,26 @@ def auth_show(agent_url: str) -> None: return output.field("Type", credentials.auth_type.value) - masked_value = ( - f"{credentials.value[:4]}...{credentials.value[-4:]}" - if len(credentials.value) > 8 - else "****" - ) - output.field("Value", masked_value) - - if credentials.auth_type == AuthType.API_KEY: - output.field("Header", credentials.header_name or "X-API-Key") + + if credentials.auth_type == AuthType.MTLS: + output.field("Certificate", credentials.cert_path or "") + output.field("Private Key", credentials.key_path or "") + if credentials.ca_cert_path: + output.field("CA Certificate", credentials.ca_cert_path) + else: + masked_value = ( + f"{credentials.value[:4]}...{credentials.value[-4:]}" + if len(credentials.value) > 8 + else "****" + ) + output.field("Value", masked_value) + + if credentials.auth_type == AuthType.API_KEY: + output.field("Header", credentials.header_name or "X-API-Key") + + if credentials.custom_headers: + for name, value in credentials.custom_headers.items(): + output.field(f"Header: {name}", value) @auth.command("clear") diff --git a/src/a2a_handler/cli/card.py b/src/a2a_handler/cli/card.py index 9454dff..dcf3f2d 100644 --- a/src/a2a_handler/cli/card.py +++ b/src/a2a_handler/cli/card.py @@ -33,10 +33,7 @@ def card() -> None: @card.command("get") @click.argument("agent_url") -@click.option( - "--authenticated", "-a", is_flag=True, help="Request authenticated extended card" -) -def card_get(agent_url: str, authenticated: bool) -> None: +def card_get(agent_url: str) -> None: """Retrieve an agent's card.""" output = Output() try: @@ -46,13 +43,11 @@ def card_get(agent_url: str, authenticated: bool) -> None: raise click.Abort() from error log.info("Fetching agent card from %s", agent_url) - credentials = get_credentials(agent_url) if authenticated else None - if authenticated and credentials is None: - log.warning("No saved credentials found for %s", agent_url) + credentials = get_credentials(agent_url) async def do_get() -> None: try: - async with build_http_client() as http_client: + async with build_http_client(credentials=credentials) as http_client: service = A2AService(http_client, agent_url, credentials=credentials) card_data = await service.get_card() log.info("Retrieved card for agent: %s", card_data.name) @@ -92,9 +87,11 @@ def card_validate(source: str) -> None: handle_validation_error(error, output) raise click.Abort() from error + credentials = get_credentials(source) if is_url else None + async def do_validate() -> None: if is_url: - async with build_http_client() as http_client: + async with build_http_client(credentials=credentials) as http_client: result = await validate_agent_card_from_url(source, http_client) else: result = validate_agent_card_from_file(source) diff --git a/src/a2a_handler/cli/message.py b/src/a2a_handler/cli/message.py index ddb02a8..ceadc6e 100644 --- a/src/a2a_handler/cli/message.py +++ b/src/a2a_handler/cli/message.py @@ -1,12 +1,19 @@ """Message commands for sending messages to A2A agents.""" import asyncio +from dataclasses import replace from typing import Any from typing import Optional import rich_click as click -from a2a_handler.auth import AuthCredentials, create_api_key_auth, create_bearer_auth +from a2a_handler.auth import ( + AuthCredentials, + AuthType, + create_api_key_auth, + create_bearer_auth, + parse_header_string, +) from a2a_handler.common import Output, get_logger from a2a_handler.common.input_validation import ( InputValidationError, @@ -49,6 +56,13 @@ def message() -> None: @click.option("--push-token", help="Authentication token for push notifications") @click.option("--bearer", "-b", "bearer_token", help="Bearer token (overrides saved)") @click.option("--api-key", "-k", help="API key (overrides saved)") +@click.option( + "--header", + "-H", + "headers", + multiple=True, + help="Custom header (repeatable, format: 'Name: Value')", +) def message_send( agent_url: str, text: Optional[str], @@ -61,6 +75,7 @@ def message_send( push_token: Optional[str], bearer_token: Optional[str], api_key: Optional[str], + headers: tuple[str, ...] = (), ) -> None: """Send a message to an agent and receive a response.""" output = Output() @@ -153,6 +168,19 @@ def message_send( context_id = session.context_id log.info("Using saved context: %s", context_id) + custom_headers: dict[str, str] | None = None + if headers: + custom_headers = {} + for h in headers: + try: + name, value = parse_header_string(h) + reject_control_chars(name, "header name") + reject_control_chars(value, "header value") + custom_headers[name] = value + except (ValueError, InputValidationError) as e: + output.error(str(e)) + raise click.Abort() from e + credentials: AuthCredentials | None = None if bearer_token: credentials = create_bearer_auth(bearer_token) @@ -161,9 +189,21 @@ def message_send( else: credentials = get_credentials(agent_url) + if custom_headers: + if credentials is None: + credentials = AuthCredentials( + auth_type=AuthType.BEARER, + custom_headers=custom_headers, + ) + else: + credentials = replace(credentials) + merged = dict(credentials.custom_headers or {}) + merged.update(custom_headers) + credentials.custom_headers = merged + async def do_send() -> None: try: - async with build_http_client() as http_client: + async with build_http_client(credentials=credentials) as http_client: service = A2AService( http_client, agent_url, @@ -203,6 +243,13 @@ async def do_send() -> None: @click.option("--push-token", help="Authentication token for push notifications") @click.option("--bearer", "-b", "bearer_token", help="Bearer token (overrides saved)") @click.option("--api-key", "-k", help="API key (overrides saved)") +@click.option( + "--header", + "-H", + "headers", + multiple=True, + help="Custom header (repeatable, format: 'Name: Value')", +) @click.pass_context def message_stream( ctx: click.Context, @@ -215,6 +262,7 @@ def message_stream( push_token: Optional[str], bearer_token: Optional[str], api_key: Optional[str], + headers: tuple[str, ...] = (), ) -> None: """Send a message and stream the response in real-time.""" ctx.invoke( @@ -229,6 +277,7 @@ def message_stream( push_token=push_token, bearer_token=bearer_token, api_key=api_key, + headers=headers, ) diff --git a/src/a2a_handler/cli/task.py b/src/a2a_handler/cli/task.py index 1c4e317..d983431 100644 --- a/src/a2a_handler/cli/task.py +++ b/src/a2a_handler/cli/task.py @@ -105,7 +105,7 @@ def task_get( async def do_get() -> None: try: - async with build_http_client() as http_client: + async with build_http_client(credentials=credentials) as http_client: service = A2AService(http_client, agent_url, credentials=credentials) result = await service.get_task(task_id, history_length) _format_task_result(result, output) @@ -152,7 +152,7 @@ def task_cancel( async def do_cancel() -> None: try: - async with build_http_client() as http_client: + async with build_http_client(credentials=credentials) as http_client: service = A2AService(http_client, agent_url, credentials=credentials) output.dim(f"Canceling task {task_id}...") @@ -205,7 +205,7 @@ def task_resubscribe( async def do_resubscribe() -> None: try: - async with build_http_client() as http_client: + async with build_http_client(credentials=credentials) as http_client: service = A2AService(http_client, agent_url, credentials=credentials) output.dim(f"Resubscribing to task {task_id}...") @@ -288,7 +288,7 @@ def notification_set( async def do_set() -> None: try: - async with build_http_client() as http_client: + async with build_http_client(credentials=credentials) as http_client: service = A2AService(http_client, agent_url, credentials=credentials) output.dim(f"Setting notification config for task {task_id}...") @@ -352,7 +352,7 @@ def notification_get( async def do_get() -> None: try: - async with build_http_client() as http_client: + async with build_http_client(credentials=credentials) as http_client: service = A2AService(http_client, agent_url, credentials=credentials) config = await service.get_push_config(task_id, config_id) diff --git a/src/a2a_handler/mcp/server.py b/src/a2a_handler/mcp/server.py index 481edc2..406c3d9 100644 --- a/src/a2a_handler/mcp/server.py +++ b/src/a2a_handler/mcp/server.py @@ -1,11 +1,18 @@ """MCP server implementation exposing A2A capabilities as tools and resources.""" +from dataclasses import replace from typing import Literal import httpx from mcp.server.fastmcp import FastMCP -from a2a_handler.auth import AuthCredentials, create_api_key_auth, create_bearer_auth +from a2a_handler.auth import ( + AuthCredentials, + AuthType, + create_api_key_auth, + create_bearer_auth, + create_mtls_auth, +) from a2a_handler.common import get_logger from a2a_handler.common.input_validation import ( InputValidationError, @@ -42,8 +49,16 @@ def _validation_error(error: InputValidationError) -> ValueError: return ValueError(f"{error.code}: {error.message}") -def _build_http_client(timeout: int = TIMEOUT) -> httpx.AsyncClient: +def _build_http_client( + timeout: int = TIMEOUT, + credentials: AuthCredentials | None = None, +) -> httpx.AsyncClient: """Build an HTTP client with the specified timeout.""" + if credentials and credentials.auth_type == AuthType.MTLS: + return httpx.AsyncClient( + timeout=timeout, + verify=credentials.build_ssl_context(), + ) return httpx.AsyncClient(timeout=timeout) @@ -51,13 +66,35 @@ def _resolve_credentials( agent_url: str, bearer_token: str | None = None, api_key: str | None = None, + cert_path: str | None = None, + key_path: str | None = None, + ca_cert_path: str | None = None, + custom_headers: dict[str, str] | None = None, ) -> AuthCredentials | None: """Resolve credentials from explicit args or saved session.""" - if bearer_token: - return create_bearer_auth(bearer_token) - if api_key: - return create_api_key_auth(api_key) - return get_credentials(agent_url) + credentials: AuthCredentials | None = None + if cert_path and key_path: + credentials = create_mtls_auth(cert_path, key_path, ca_cert_path) + elif bearer_token: + credentials = create_bearer_auth(bearer_token) + elif api_key: + credentials = create_api_key_auth(api_key) + else: + credentials = get_credentials(agent_url) + + if custom_headers: + if credentials is None: + credentials = AuthCredentials( + auth_type=AuthType.BEARER, + custom_headers=custom_headers, + ) + else: + credentials = replace(credentials) + merged = dict(credentials.custom_headers or {}) + merged.update(custom_headers) + credentials.custom_headers = merged + + return credentials def create_mcp_server() -> FastMCP: @@ -165,8 +202,10 @@ async def get_agent_card(agent_url: str) -> dict: except InputValidationError as error: raise _validation_error(error) from error - async with _build_http_client() as http_client: - service = A2AService(http_client, agent_url) + credentials = _resolve_credentials(agent_url) + + async with _build_http_client(credentials=credentials) as http_client: + service = A2AService(http_client, agent_url, credentials=credentials) card = await service.get_card() return card.model_dump(exclude_none=True) @@ -180,6 +219,10 @@ async def send_message( use_session: bool = False, bearer_token: str | None = None, api_key: str | None = None, + cert_path: str | None = None, + key_path: str | None = None, + ca_cert_path: str | None = None, + custom_headers: dict[str, str] | None = None, ) -> dict: """Send a message to an A2A agent and receive a response. @@ -225,9 +268,17 @@ async def send_message( context_id = session.context_id logger.info("Using saved context: %s", context_id) - credentials = _resolve_credentials(agent_url, bearer_token, api_key) - - async with _build_http_client() as http_client: + credentials = _resolve_credentials( + agent_url, + bearer_token, + api_key, + cert_path, + key_path, + ca_cert_path, + custom_headers, + ) + + async with _build_http_client(credentials=credentials) as http_client: service = A2AService( http_client, agent_url, @@ -253,6 +304,10 @@ async def get_task( history_length: int | None = None, bearer_token: str | None = None, api_key: str | None = None, + cert_path: str | None = None, + key_path: str | None = None, + ca_cert_path: str | None = None, + custom_headers: dict[str, str] | None = None, ) -> dict: """Get the current status and details of a task. @@ -284,9 +339,17 @@ async def get_task( except InputValidationError as error: raise _validation_error(error) from error - credentials = _resolve_credentials(agent_url, bearer_token, api_key) - - async with _build_http_client() as http_client: + credentials = _resolve_credentials( + agent_url, + bearer_token, + api_key, + cert_path, + key_path, + ca_cert_path, + custom_headers, + ) + + async with _build_http_client(credentials=credentials) as http_client: service = A2AService(http_client, agent_url, credentials=credentials) result = await service.get_task(task_id, history_length) @@ -303,6 +366,10 @@ async def cancel_task( task_id: str, bearer_token: str | None = None, api_key: str | None = None, + cert_path: str | None = None, + key_path: str | None = None, + ca_cert_path: str | None = None, + custom_headers: dict[str, str] | None = None, ) -> dict: """Cancel a running task. @@ -332,9 +399,17 @@ async def cancel_task( except InputValidationError as error: raise _validation_error(error) from error - credentials = _resolve_credentials(agent_url, bearer_token, api_key) - - async with _build_http_client() as http_client: + credentials = _resolve_credentials( + agent_url, + bearer_token, + api_key, + cert_path, + key_path, + ca_cert_path, + custom_headers, + ) + + async with _build_http_client(credentials=credentials) as http_client: service = A2AService(http_client, agent_url, credentials=credentials) result = await service.cancel_task(task_id) @@ -353,6 +428,10 @@ async def set_task_notification( webhook_token: str | None = None, bearer_token: str | None = None, api_key: str | None = None, + cert_path: str | None = None, + key_path: str | None = None, + ca_cert_path: str | None = None, + custom_headers: dict[str, str] | None = None, ) -> dict: """Configure push notifications for a task. @@ -388,9 +467,17 @@ async def set_task_notification( except InputValidationError as error: raise _validation_error(error) from error - credentials = _resolve_credentials(agent_url, bearer_token, api_key) - - async with _build_http_client() as http_client: + credentials = _resolve_credentials( + agent_url, + bearer_token, + api_key, + cert_path, + key_path, + ca_cert_path, + custom_headers, + ) + + async with _build_http_client(credentials=credentials) as http_client: service = A2AService(http_client, agent_url, credentials=credentials) config = await service.set_push_config(task_id, webhook_url, webhook_token) @@ -412,6 +499,10 @@ async def get_task_notification( config_id: str | None = None, bearer_token: str | None = None, api_key: str | None = None, + cert_path: str | None = None, + key_path: str | None = None, + ca_cert_path: str | None = None, + custom_headers: dict[str, str] | None = None, ) -> dict: """Get the push notification configuration for a task. @@ -444,9 +535,17 @@ async def get_task_notification( except InputValidationError as error: raise _validation_error(error) from error - credentials = _resolve_credentials(agent_url, bearer_token, api_key) - - async with _build_http_client() as http_client: + credentials = _resolve_credentials( + agent_url, + bearer_token, + api_key, + cert_path, + key_path, + ca_cert_path, + custom_headers, + ) + + async with _build_http_client(credentials=credentials) as http_client: service = A2AService(http_client, agent_url, credentials=credentials) config = await service.get_push_config(task_id, config_id) @@ -557,36 +656,54 @@ async def set_agent_credentials( agent_url: str, bearer_token: str | None = None, api_key: str | None = None, + cert_path: str | None = None, + key_path: str | None = None, + ca_cert_path: str | None = None, + custom_headers: dict[str, str] | None = None, ) -> dict: """Set authentication credentials for an agent. Saves credentials that will be used for all future requests to this - agent. Either bearer_token or api_key should be provided, not both. + agent. Provide one of: bearer_token, api_key, or cert_path/key_path + for mTLS. Args: agent_url: Base URL of the A2A agent bearer_token: Bearer token for Authorization header api_key: API key for X-API-Key header + cert_path: Client certificate path for mTLS (PEM) + key_path: Client private key path for mTLS (PEM) + ca_cert_path: CA certificate path for mTLS server verification (PEM) Returns: A dictionary containing: - agent_url: The agent URL - - auth_type: Type of auth configured ("bearer" or "api_key") + - auth_type: Type of auth configured ("bearer", "api_key", or "mtls") """ logger.info("Setting credentials for %s", agent_url) try: validate_agent_url(agent_url) - if bearer_token and api_key: + + has_mtls = cert_path or key_path + method_count = sum(bool(x) for x in [bearer_token, api_key, has_mtls]) + + if method_count > 1: raise InputValidationError( code="invalid_auth_arguments", - message="Provide either bearer_token or api_key, not both", + message="Provide only one auth method: bearer_token, api_key, or cert_path/key_path", suggestion="Pass only one auth mechanism per call", ) - if not bearer_token and not api_key: + if method_count == 0 and not custom_headers: raise InputValidationError( code="missing_auth_arguments", - message="Either bearer_token or api_key is required", - suggestion="Provide bearer_token or api_key", + message="Provide bearer_token, api_key, cert_path/key_path, or custom_headers", + suggestion="Provide at least one auth mechanism or custom headers", + ) + if has_mtls and (not cert_path or not key_path): + raise InputValidationError( + code="incomplete_mtls_arguments", + message="mTLS requires both cert_path and key_path", + suggestion="Provide both cert_path and key_path", ) if bearer_token: reject_control_chars(bearer_token, "bearer_token") @@ -595,16 +712,25 @@ async def set_agent_credentials( except InputValidationError as error: raise _validation_error(error) from error - if bearer_token: + credentials: AuthCredentials + if cert_path and key_path: + credentials = create_mtls_auth(cert_path, key_path, ca_cert_path) + auth_type = "mtls" + elif bearer_token: credentials = create_bearer_auth(bearer_token) - set_credentials(agent_url, credentials) - return {"agent_url": agent_url, "auth_type": "bearer"} - if api_key: + auth_type = "bearer" + elif api_key: credentials = create_api_key_auth(api_key) - set_credentials(agent_url, credentials) - return {"agent_url": agent_url, "auth_type": "api_key"} + auth_type = "api_key" + elif custom_headers: + credentials = AuthCredentials(auth_type=AuthType.BEARER) + auth_type = "custom_headers" + else: + raise AssertionError("validated auth inputs should guarantee a return") - raise AssertionError("validated auth inputs should guarantee a return") + credentials.custom_headers = custom_headers + set_credentials(agent_url, credentials) + return {"agent_url": agent_url, "auth_type": auth_type} @mcp.tool() async def clear_agent_credentials(agent_url: str) -> dict: diff --git a/src/a2a_handler/service.py b/src/a2a_handler/service.py index 99021c6..208c7fc 100644 --- a/src/a2a_handler/service.py +++ b/src/a2a_handler/service.py @@ -33,7 +33,7 @@ PREV_AGENT_CARD_WELL_KNOWN_PATH, ) -from a2a_handler.auth import AuthCredentials +from a2a_handler.auth import AuthCredentials, AuthType from a2a_handler.common import get_logger from a2a_handler.common.input_validation import ( reject_control_chars, @@ -270,12 +270,19 @@ def set_credentials(self, credentials: AuthCredentials) -> None: self._applied_auth_headers.clear() self.credentials = credentials - auth_headers = credentials.to_headers() - self.http_client.headers.update(auth_headers) - self._applied_auth_headers = set(auth_headers.keys()) - # Rebuild the SDK client so updated headers are guaranteed to be used. self._cached_client = None - logger.debug("Applied authentication headers: %s", list(auth_headers.keys())) + + auth_headers = credentials.to_headers() + if auth_headers: + self.http_client.headers.update(auth_headers) + self._applied_auth_headers = set(auth_headers.keys()) + + if credentials.auth_type == AuthType.MTLS: + logger.debug("mTLS credentials set (transport-level authentication)") + else: + logger.debug( + "Applied authentication headers: %s", list(auth_headers.keys()) + ) def clear_credentials(self) -> None: """Clear authentication credentials from the service and HTTP client.""" diff --git a/src/a2a_handler/tui/app.py b/src/a2a_handler/tui/app.py index 6c59cb7..dd60f5a 100644 --- a/src/a2a_handler/tui/app.py +++ b/src/a2a_handler/tui/app.py @@ -19,9 +19,10 @@ from textual.screen import Screen from textual.widgets import Button, Footer, Input -from a2a_handler.auth import AuthCredentials +from a2a_handler.auth import AuthCredentials, AuthType from a2a_handler.common import get_theme, install_tui_log_handler, save_theme from a2a_handler.service import A2AService +from a2a_handler.session import get_credentials from a2a_handler.tui.components import ( AgentCardPanel, ContactPanel, @@ -42,8 +43,14 @@ def build_http_client( timeout_seconds: int = DEFAULT_HTTP_TIMEOUT_SECONDS, + credentials: AuthCredentials | None = None, ) -> httpx.AsyncClient: """Build an HTTP client with the specified timeout.""" + if credentials and credentials.auth_type == AuthType.MTLS: + return httpx.AsyncClient( + timeout=timeout_seconds, + verify=credentials.build_ssl_context(), + ) return httpx.AsyncClient(timeout=timeout_seconds) @@ -134,8 +141,9 @@ async def _connect_to_agent( agent_url: str, credentials: AuthCredentials | None = None, ) -> AgentCard: - if not self.http_client: - raise RuntimeError("HTTP client not initialized") + if self.http_client: + await self.http_client.aclose() + self.http_client = build_http_client(credentials=credentials) logger.info("Connecting to agent at %s", agent_url) self._agent_service = A2AService( @@ -152,6 +160,12 @@ def _update_ui_for_connected_state(self, agent_card: AgentCard) -> None: messages_panel = self.query_one("#messages-container", TabbedMessagesPanel) messages_panel.update_message_count() + @on(AgentCardPanel.AgentSelected) + async def handle_agent_selected(self, event: AgentCardPanel.AgentSelected) -> None: + contact_panel = self.query_one("#contact-container", ContactPanel) + contact_panel.set_url(event.agent_url) + await self._do_connect(event.agent_url, get_credentials(event.agent_url)) + @on(Button.Pressed, "#connect-btn") async def handle_connect_button(self) -> None: contact_panel = self.query_one("#contact-container", ContactPanel) @@ -163,11 +177,22 @@ async def handle_connect_button(self) -> None: messages_panel.add_system_message("Please enter an agent URL") return + messages_panel = self.query_one("#messages-container", TabbedMessagesPanel) + credentials = messages_panel.get_auth_credentials() + if credentials is None: + credentials = get_credentials(agent_url) + + await self._do_connect(agent_url, credentials) + + async def _do_connect( + self, + agent_url: str, + credentials: AuthCredentials | None = None, + ) -> None: messages_panel = self.query_one("#messages-container", TabbedMessagesPanel) messages_panel.add_system_message(f"Connecting to {agent_url}...") try: - credentials = messages_panel.get_auth_credentials() agent_card = await self._connect_to_agent(agent_url, credentials) self.current_agent_card = agent_card @@ -227,8 +252,6 @@ async def _send_message(self) -> None: credentials = messages_panel.get_auth_credentials() if credentials: self._agent_service.set_credentials(credentials) - else: - self._agent_service.clear_credentials() send_result = await self._agent_service.send( message_text, diff --git a/src/a2a_handler/tui/app.tcss b/src/a2a_handler/tui/app.tcss index 6d59da7..9635d45 100644 --- a/src/a2a_handler/tui/app.tcss +++ b/src/a2a_handler/tui/app.tcss @@ -134,8 +134,25 @@ Button:hover { height: 1fr; width: 1fr; hatch: cross $secondary 30%; - content-align: center middle; - color: $text-muted; + align: center middle; + padding: 1 2; +} + +.saved-agents-title { + color: $text; + text-style: bold; + padding: 0 0 1 0; + width: 100%; + text-align: center; +} + +.saved-agent-btn { + width: 100%; + margin: 0 0 1 0; +} + +.saved-agent-url { + display: none; } #raw-scroll { diff --git a/src/a2a_handler/tui/components/auth.py b/src/a2a_handler/tui/components/auth.py index f7a50ca..a9626b3 100644 --- a/src/a2a_handler/tui/components/auth.py +++ b/src/a2a_handler/tui/components/auth.py @@ -11,6 +11,8 @@ AuthType, create_api_key_auth, create_bearer_auth, + create_mtls_auth, + parse_header_string, ) from a2a_handler.common import get_logger @@ -28,6 +30,7 @@ def compose(self) -> ComposeResult: yield RadioButton("None", id="auth-none", value=True) yield RadioButton("API Key", id="auth-api-key") yield RadioButton("Bearer Token", id="auth-bearer") + yield RadioButton("mTLS (Client Certificate)", id="auth-mtls") with Vertical(id="api-key-fields", classes="auth-fields hidden"): yield Label("API Key") @@ -41,13 +44,29 @@ def compose(self) -> ComposeResult: placeholder="Enter bearer token", id="bearer-token-input", password=True ) + with Vertical(id="mtls-fields", classes="auth-fields hidden"): + yield Label("Client Certificate") + yield Input(placeholder="/path/to/client.crt", id="mtls-cert-input") + yield Label("Client Private Key") + yield Input(placeholder="/path/to/client.key", id="mtls-key-input") + yield Label("CA Certificate (optional)") + yield Input(placeholder="/path/to/ca.crt", id="mtls-ca-input") + + yield Label("Custom Headers (optional, semicolon-separated)") + yield Input( + placeholder="x-user-id: me@mydomain.com; x-org: acme", + id="custom-headers-input", + ) + def on_radio_set_changed(self, event: RadioSet.Changed) -> None: """Handle auth type selection changes.""" api_key_fields = self.query_one("#api-key-fields", Vertical) bearer_fields = self.query_one("#bearer-fields", Vertical) + mtls_fields = self.query_one("#mtls-fields", Vertical) api_key_fields.add_class("hidden") bearer_fields.add_class("hidden") + mtls_fields.add_class("hidden") if event.pressed.id == "auth-api-key": api_key_fields.remove_class("hidden") @@ -55,9 +74,28 @@ def on_radio_set_changed(self, event: RadioSet.Changed) -> None: elif event.pressed.id == "auth-bearer": bearer_fields.remove_class("hidden") logger.debug("Auth type changed to Bearer Token") + elif event.pressed.id == "auth-mtls": + mtls_fields.remove_class("hidden") + logger.debug("Auth type changed to mTLS") else: logger.debug("Auth type changed to None") + def _parse_custom_headers(self) -> dict[str, str] | None: + raw = self.query_one("#custom-headers-input", Input).value.strip() + if not raw: + return None + headers: dict[str, str] = {} + for line in raw.split(";"): + line = line.strip() + if not line: + continue + try: + name, value = parse_header_string(line) + headers[name] = value + except ValueError: + logger.warning("Skipping invalid header: %s", line) + return headers or None + def get_credentials(self) -> AuthCredentials | None: """Get the configured authentication credentials. @@ -66,24 +104,43 @@ def get_credentials(self) -> AuthCredentials | None: """ radio_set = self.query_one("#auth-type-selector", RadioSet) pressed = radio_set.pressed_button + custom_headers = self._parse_custom_headers() - if pressed is None or pressed.id == "auth-none": - return None + credentials: AuthCredentials | None = None - if pressed.id == "auth-api-key": + if pressed is not None and pressed.id == "auth-api-key": api_key = self.query_one("#api-key-input", Input).value header_name = ( self.query_one("#api-key-header-input", Input).value or "X-API-Key" ) if api_key: - return create_api_key_auth(api_key, header_name=header_name) + credentials = create_api_key_auth(api_key, header_name=header_name) - elif pressed.id == "auth-bearer": + elif pressed is not None and pressed.id == "auth-bearer": token = self.query_one("#bearer-token-input", Input).value if token: - return create_bearer_auth(token) - - return None + credentials = create_bearer_auth(token) + + elif pressed is not None and pressed.id == "auth-mtls": + cert_path = self.query_one("#mtls-cert-input", Input).value + key_path = self.query_one("#mtls-key-input", Input).value + ca_cert_path = self.query_one("#mtls-ca-input", Input).value or None + if cert_path and key_path: + try: + credentials = create_mtls_auth(cert_path, key_path, ca_cert_path) + except FileNotFoundError: + logger.warning("mTLS certificate file not found") + + if custom_headers: + if credentials is None: + credentials = AuthCredentials( + auth_type=AuthType.BEARER, + custom_headers=custom_headers, + ) + else: + credentials.custom_headers = custom_headers + + return credentials def get_auth_type(self) -> AuthType | None: """Get the currently selected auth type.""" @@ -96,6 +153,8 @@ def get_auth_type(self) -> AuthType | None: return AuthType.API_KEY elif pressed.id == "auth-bearer": return AuthType.BEARER + elif pressed.id == "auth-mtls": + return AuthType.MTLS return None def set_bearer_token(self, token: str) -> None: diff --git a/src/a2a_handler/tui/components/card.py b/src/a2a_handler/tui/components/card.py index ad54c46..e407b6d 100644 --- a/src/a2a_handler/tui/components/card.py +++ b/src/a2a_handler/tui/components/card.py @@ -7,10 +7,13 @@ from rich.syntax import Syntax from textual.app import ComposeResult from textual.binding import Binding -from textual.containers import Container, VerticalScroll -from textual.widgets import Static +from textual.containers import Container, Vertical, VerticalScroll +from textual.message import Message +from textual.widgets import Button, Static +from a2a_handler.auth import AuthType from a2a_handler.common import get_logger +from a2a_handler.session import get_session_store logger = get_logger(__name__) @@ -22,17 +25,28 @@ "dracula": "dracula", } +AUTH_TYPE_LABELS: dict[AuthType, str] = { + AuthType.BEARER: "Bearer", + AuthType.API_KEY: "API Key", + AuthType.MTLS: "mTLS", +} + class AgentCardPanel(Container): """Panel displaying agent card information with tabs.""" + class AgentSelected(Message): + def __init__(self, agent_url: str) -> None: + super().__init__() + self.agent_url = agent_url + BINDINGS = [ - Binding("j", "scroll_down", "↓ Scroll", show=True, key_display="j/↓"), - Binding("k", "scroll_up", "↑ Scroll", show=True, key_display="k/↑"), + Binding("j", "scroll_down", "\u2193 Scroll", show=True, key_display="j/\u2193"), + Binding("k", "scroll_up", "\u2191 Scroll", show=True, key_display="k/\u2191"), Binding("down", "scroll_down", "Scroll Down", show=False), Binding("up", "scroll_up", "Scroll Up", show=False), - Binding("ctrl+d", "scroll_half_down", "½ Page ↓", show=True), - Binding("ctrl+u", "scroll_half_up", "½ Page ↑", show=True), + Binding("ctrl+d", "scroll_half_down", "\u00bd Page \u2193", show=True), + Binding("ctrl+u", "scroll_half_up", "\u00bd Page \u2191", show=True), ] can_focus = True @@ -46,9 +60,10 @@ def check_action(self, action: str, parameters: tuple[object, ...]) -> bool | No def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self._current_agent_card: AgentCard | None = None + self._button_url_map: dict[str, str] = {} def compose(self) -> ComposeResult: - yield Static("Connect to an A2A server", id="placeholder") + yield Vertical(id="placeholder") yield VerticalScroll( Static("", id="agent-raw"), id="raw-scroll", @@ -57,19 +72,60 @@ def compose(self) -> ComposeResult: def on_mount(self) -> None: for widget in self.query("VerticalScroll"): widget.can_focus = False + self._populate_saved_agents() self._show_placeholder() logger.debug("Agent card panel mounted") + def _populate_saved_agents(self) -> None: + placeholder = self.query_one("#placeholder", Vertical) + placeholder.remove_children() + self._button_url_map.clear() + + store = get_session_store() + sessions = store.list_all() + agents_with_creds = [s for s in sessions if s.credentials is not None] + + if not agents_with_creds: + placeholder.mount(Static("Connect to an A2A server")) + return + + placeholder.mount(Static("Saved Agents", classes="saved-agents-title")) + for idx, session in enumerate(agents_with_creds): + auth_label = "" + if session.credentials: + auth_label = AUTH_TYPE_LABELS.get(session.credentials.auth_type, "") + if session.credentials.custom_headers: + header_names = ", ".join(session.credentials.custom_headers.keys()) + if auth_label: + auth_label = f"{auth_label} + {header_names}" + else: + auth_label = header_names + + label = session.agent_url + if auth_label: + label = f"{session.agent_url} [{auth_label}]" + + button_id = f"saved-agent-{idx}" + self._button_url_map[button_id] = session.agent_url + placeholder.mount(Button(label, id=button_id, classes="saved-agent-btn")) + + def _on_button_pressed(self, event: Button.Pressed) -> None: + if "saved-agent-btn" not in event.button.classes: + return + agent_url = self._button_url_map.get(event.button.id or "") + if agent_url: + self.post_message(self.AgentSelected(agent_url)) + def _show_placeholder(self) -> None: """Show the hatch placeholder, hide the raw scroll content.""" - placeholder = self.query_one("#placeholder", Static) + placeholder = self.query_one("#placeholder", Vertical) raw_scroll = self.query_one("#raw-scroll", VerticalScroll) placeholder.display = True raw_scroll.display = False def _show_content(self) -> None: """Show the raw scroll content, hide the placeholder.""" - placeholder = self.query_one("#placeholder", Static) + placeholder = self.query_one("#placeholder", Vertical) raw_scroll = self.query_one("#raw-scroll", VerticalScroll) placeholder.display = False raw_scroll.display = True @@ -89,6 +145,7 @@ def update_card(self, agent_card: AgentCard | None) -> None: if agent_card is None: logger.debug("Clearing agent card display") raw_view_widget.update("") + self._populate_saved_agents() self._show_placeholder() else: logger.info("Displaying agent card for: %s", agent_card.name) diff --git a/src/a2a_handler/tui/components/contact.py b/src/a2a_handler/tui/components/contact.py index 20c94f6..190e639 100644 --- a/src/a2a_handler/tui/components/contact.py +++ b/src/a2a_handler/tui/components/contact.py @@ -166,3 +166,8 @@ def get_url(self) -> str: """Get the current agent URL from the input field.""" url_input = self.query_one("#agent-url", Input) return url_input.value.strip() + + def set_url(self, url: str) -> None: + """Set the agent URL input field.""" + url_input = self.query_one("#agent-url", Input) + url_input.value = url diff --git a/tests/test_auth.py b/tests/test_auth.py index af7b44b..932cbcb 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,10 +1,16 @@ """Tests for authentication module.""" +import tempfile + +import pytest + from a2a_handler.auth import ( AuthCredentials, AuthType, create_api_key_auth, create_bearer_auth, + create_mtls_auth, + parse_header_string, ) @@ -77,3 +83,169 @@ def test_create_api_key_auth_custom(self) -> None: """create_api_key_auth with custom header.""" creds = create_api_key_auth("my-key", header_name="Authorization") assert creds.header_name == "Authorization" + + +class TestMTLSAuth: + def test_mtls_to_headers_returns_empty(self) -> None: + creds = AuthCredentials( + auth_type=AuthType.MTLS, + cert_path="/tmp/cert.pem", + key_path="/tmp/key.pem", + ) + assert creds.to_headers() == {} + + def test_mtls_to_dict_and_from_dict(self) -> None: + original = AuthCredentials( + auth_type=AuthType.MTLS, + cert_path="/tmp/cert.pem", + key_path="/tmp/key.pem", + ca_cert_path="/tmp/ca.pem", + ) + data = original.to_dict() + restored = AuthCredentials.from_dict(data) + + assert restored.auth_type == AuthType.MTLS + assert restored.cert_path == "/tmp/cert.pem" + assert restored.key_path == "/tmp/key.pem" + assert restored.ca_cert_path == "/tmp/ca.pem" + + def test_mtls_to_dict_without_ca_cert(self) -> None: + creds = AuthCredentials( + auth_type=AuthType.MTLS, + cert_path="/tmp/cert.pem", + key_path="/tmp/key.pem", + ) + data = creds.to_dict() + assert "ca_cert_path" not in data + + def test_create_mtls_auth_validates_cert_exists(self) -> None: + with pytest.raises(FileNotFoundError, match="Client certificate not found"): + create_mtls_auth("/nonexistent/cert.pem", "/nonexistent/key.pem") + + def test_create_mtls_auth_validates_key_exists(self) -> None: + with tempfile.NamedTemporaryFile(suffix=".pem") as cert_file: + with pytest.raises(FileNotFoundError, match="Client private key not found"): + create_mtls_auth(cert_file.name, "/nonexistent/key.pem") + + def test_create_mtls_auth_validates_ca_cert_exists(self) -> None: + with ( + tempfile.NamedTemporaryFile(suffix=".pem") as cert_file, + tempfile.NamedTemporaryFile(suffix=".pem") as key_file, + ): + with pytest.raises(FileNotFoundError, match="CA certificate not found"): + create_mtls_auth(cert_file.name, key_file.name, "/nonexistent/ca.pem") + + def test_create_mtls_auth_success(self) -> None: + with ( + tempfile.NamedTemporaryFile(suffix=".pem") as cert_file, + tempfile.NamedTemporaryFile(suffix=".pem") as key_file, + ): + creds = create_mtls_auth(cert_file.name, key_file.name) + assert creds.auth_type == AuthType.MTLS + assert creds.cert_path == cert_file.name + assert creds.key_path == key_file.name + assert creds.ca_cert_path is None + + def test_create_mtls_auth_with_ca_cert(self) -> None: + with ( + tempfile.NamedTemporaryFile(suffix=".pem") as cert_file, + tempfile.NamedTemporaryFile(suffix=".pem") as key_file, + tempfile.NamedTemporaryFile(suffix=".pem") as ca_file, + ): + creds = create_mtls_auth(cert_file.name, key_file.name, ca_file.name) + assert creds.ca_cert_path == ca_file.name + + def test_build_ssl_context_rejects_non_mtls(self) -> None: + creds = AuthCredentials(auth_type=AuthType.BEARER, value="token") + with pytest.raises(ValueError, match="mTLS"): + creds.build_ssl_context() + + def test_build_ssl_context_requires_paths(self) -> None: + creds = AuthCredentials(auth_type=AuthType.MTLS) + with pytest.raises(ValueError, match="cert_path and key_path"): + creds.build_ssl_context() + + +class TestCustomHeaders: + def test_custom_headers_merged_with_bearer(self) -> None: + creds = AuthCredentials( + auth_type=AuthType.BEARER, + value="token", + custom_headers={"x-user-id": "me@example.com"}, + ) + headers = creds.to_headers() + assert headers["Authorization"] == "Bearer token" + assert headers["x-user-id"] == "me@example.com" + + def test_custom_headers_merged_with_api_key(self) -> None: + creds = AuthCredentials( + auth_type=AuthType.API_KEY, + value="key123", + custom_headers={"x-org": "acme"}, + ) + headers = creds.to_headers() + assert headers["X-API-Key"] == "key123" + assert headers["x-org"] == "acme" + + def test_custom_headers_with_mtls(self) -> None: + creds = AuthCredentials( + auth_type=AuthType.MTLS, + cert_path="/tmp/cert.pem", + key_path="/tmp/key.pem", + custom_headers={"x-user-id": "me@example.com"}, + ) + headers = creds.to_headers() + assert headers == {"x-user-id": "me@example.com"} + + def test_custom_headers_only(self) -> None: + creds = AuthCredentials( + auth_type=AuthType.BEARER, + custom_headers={"x-user-id": "me@example.com", "x-org": "acme"}, + ) + headers = creds.to_headers() + assert "Authorization" not in headers + assert headers == {"x-user-id": "me@example.com", "x-org": "acme"} + + def test_custom_headers_roundtrip_serialization(self) -> None: + original = AuthCredentials( + auth_type=AuthType.BEARER, + value="token", + custom_headers={"x-user-id": "me@example.com", "x-org": "acme"}, + ) + data = original.to_dict() + restored = AuthCredentials.from_dict(data) + + assert restored.custom_headers == { + "x-user-id": "me@example.com", + "x-org": "acme", + } + + def test_no_custom_headers_not_in_dict(self) -> None: + creds = AuthCredentials(auth_type=AuthType.BEARER, value="token") + data = creds.to_dict() + assert "custom_headers" not in data + + +class TestParseHeaderString: + def test_parse_valid_header(self) -> None: + name, value = parse_header_string("x-user-id: me@example.com") + assert name == "x-user-id" + assert value == "me@example.com" + + def test_parse_header_with_extra_colons(self) -> None: + name, value = parse_header_string("x-data: value:with:colons") + assert name == "x-data" + assert value == "value:with:colons" + + def test_parse_header_strips_whitespace(self) -> None: + name, value = parse_header_string(" x-user-id : me@example.com ") + assert name == "x-user-id" + assert value == "me@example.com" + + def test_parse_header_no_colon_raises(self) -> None: + with pytest.raises(ValueError, match="Invalid header format"): + parse_header_string("no-colon-here") + + def test_parse_header_empty_name_raises(self) -> None: + with pytest.raises(ValueError, match="Empty header name"): + parse_header_string(": value") diff --git a/tests/test_cli_auth.py b/tests/test_cli_auth.py index 508fd46..bb1052a 100644 --- a/tests/test_cli_auth.py +++ b/tests/test_cli_auth.py @@ -8,7 +8,12 @@ from click.testing import CliRunner from a2a_handler.cli.auth import auth -from a2a_handler.auth import AuthType, create_bearer_auth, create_api_key_auth +from a2a_handler.auth import ( + AuthCredentials, + AuthType, + create_bearer_auth, + create_api_key_auth, +) from a2a_handler.session import SessionStore @@ -95,7 +100,7 @@ def test_set_both_bearer_and_api_key_fails(self, runner): ) assert result.exit_code == 1 - assert "not both" in result.output.lower() or "either" in result.output.lower() + assert "only one auth method" in result.output.lower() def test_set_neither_bearer_nor_api_key_fails(self, runner): """Test that providing neither bearer nor API key fails.""" @@ -184,3 +189,222 @@ def test_clear_rejects_invalid_agent_url(self, runner): assert result.exit_code == 1 assert "agent_url must be a valid http(s) URL" in result.output + + +class TestAuthSetMTLS: + @pytest.fixture + def runner(self): + return CliRunner() + + def test_set_mtls_credentials(self, runner): + with ( + tempfile.NamedTemporaryFile(suffix=".pem") as cert_file, + tempfile.NamedTemporaryFile(suffix=".pem") as key_file, + ): + with patch("a2a_handler.cli.auth.set_credentials") as mock_set: + result = runner.invoke( + auth, + [ + "set", + "http://localhost:8000", + "--cert", + cert_file.name, + "--key", + key_file.name, + ], + ) + + assert result.exit_code == 0 + assert "mTLS" in result.output + mock_set.assert_called_once() + call_args = mock_set.call_args + assert call_args[0][1].auth_type == AuthType.MTLS + assert call_args[0][1].cert_path == cert_file.name + assert call_args[0][1].key_path == key_file.name + + def test_set_mtls_with_ca_cert(self, runner): + with ( + tempfile.NamedTemporaryFile(suffix=".pem") as cert_file, + tempfile.NamedTemporaryFile(suffix=".pem") as key_file, + tempfile.NamedTemporaryFile(suffix=".pem") as ca_file, + ): + with patch("a2a_handler.cli.auth.set_credentials") as mock_set: + result = runner.invoke( + auth, + [ + "set", + "http://localhost:8000", + "--cert", + cert_file.name, + "--key", + key_file.name, + "--ca-cert", + ca_file.name, + ], + ) + + assert result.exit_code == 0 + call_args = mock_set.call_args + assert call_args[0][1].ca_cert_path == ca_file.name + + def test_set_mtls_missing_key_fails(self, runner): + with tempfile.NamedTemporaryFile(suffix=".pem") as cert_file: + result = runner.invoke( + auth, + ["set", "http://localhost:8000", "--cert", cert_file.name], + ) + assert result.exit_code == 1 + + def test_set_mtls_missing_cert_fails(self, runner): + with tempfile.NamedTemporaryFile(suffix=".pem") as key_file: + result = runner.invoke( + auth, + ["set", "http://localhost:8000", "--key", key_file.name], + ) + assert result.exit_code == 1 + + def test_set_mtls_and_bearer_fails(self, runner): + with ( + tempfile.NamedTemporaryFile(suffix=".pem") as cert_file, + tempfile.NamedTemporaryFile(suffix=".pem") as key_file, + ): + result = runner.invoke( + auth, + [ + "set", + "http://localhost:8000", + "--cert", + cert_file.name, + "--key", + key_file.name, + "--bearer", + "token", + ], + ) + assert result.exit_code == 1 + + def test_set_mtls_nonexistent_cert_fails(self, runner): + with tempfile.NamedTemporaryFile(suffix=".pem") as key_file: + result = runner.invoke( + auth, + [ + "set", + "http://localhost:8000", + "--cert", + "/nonexistent/cert.pem", + "--key", + key_file.name, + ], + ) + assert result.exit_code == 1 + + def test_show_mtls_credentials(self, runner): + mock_creds = AuthCredentials( + auth_type=AuthType.MTLS, + cert_path="/path/to/cert.pem", + key_path="/path/to/key.pem", + ca_cert_path="/path/to/ca.pem", + ) + + with patch("a2a_handler.cli.auth.get_credentials", return_value=mock_creds): + result = runner.invoke(auth, ["show", "http://localhost:8000"]) + + assert result.exit_code == 0 + assert "mtls" in result.output.lower() + assert "/path/to/cert.pem" in result.output + assert "/path/to/key.pem" in result.output + assert "/path/to/ca.pem" in result.output + + +class TestAuthSetCustomHeaders: + @pytest.fixture + def runner(self): + return CliRunner() + + def test_set_bearer_with_custom_headers(self, runner): + with patch("a2a_handler.cli.auth.set_credentials") as mock_set: + result = runner.invoke( + auth, + [ + "set", + "http://localhost:8000", + "--bearer", + "my-token", + "--header", + "x-user-id: me@example.com", + ], + ) + + assert result.exit_code == 0 + assert "x-user-id" in result.output + call_args = mock_set.call_args + creds = call_args[0][1] + assert creds.auth_type == AuthType.BEARER + assert creds.custom_headers == {"x-user-id": "me@example.com"} + + def test_set_multiple_custom_headers(self, runner): + with patch("a2a_handler.cli.auth.set_credentials") as mock_set: + result = runner.invoke( + auth, + [ + "set", + "http://localhost:8000", + "--bearer", + "my-token", + "-H", + "x-user-id: me@example.com", + "-H", + "x-org: acme", + ], + ) + + assert result.exit_code == 0 + creds = mock_set.call_args[0][1] + assert creds.custom_headers == { + "x-user-id": "me@example.com", + "x-org": "acme", + } + + def test_set_headers_only(self, runner): + with patch("a2a_handler.cli.auth.set_credentials") as mock_set: + result = runner.invoke( + auth, + [ + "set", + "http://localhost:8000", + "--header", + "x-user-id: me@example.com", + ], + ) + + assert result.exit_code == 0 + creds = mock_set.call_args[0][1] + assert creds.custom_headers == {"x-user-id": "me@example.com"} + + def test_set_invalid_header_format_fails(self, runner): + result = runner.invoke( + auth, + [ + "set", + "http://localhost:8000", + "--bearer", + "token", + "--header", + "no-colon-here", + ], + ) + assert result.exit_code == 1 + + def test_show_custom_headers(self, runner): + mock_creds = AuthCredentials( + auth_type=AuthType.BEARER, + value="my-token-value-here", + custom_headers={"x-user-id": "me@example.com"}, + ) + + with patch("a2a_handler.cli.auth.get_credentials", return_value=mock_creds): + result = runner.invoke(auth, ["show", "http://localhost:8000"]) + + assert result.exit_code == 0 + assert "x-user-id" in result.output + assert "me@example.com" in result.output diff --git a/tests/test_cli_card.py b/tests/test_cli_card.py index 90c7f0d..4cc7ee8 100644 --- a/tests/test_cli_card.py +++ b/tests/test_cli_card.py @@ -90,8 +90,8 @@ def test_card_get_connection_error(self, runner): assert result.exit_code == 1 - def test_card_get_authenticated_uses_saved_credentials(self, runner): - """Test authenticated card get passes stored credentials to service.""" + def test_card_get_uses_saved_credentials(self, runner): + """Test card get passes stored credentials to service.""" mock_card = _make_agent_card() credentials = create_bearer_auth("test-token") @@ -110,7 +110,7 @@ def test_card_get_authenticated_uses_saved_credentials(self, runner): mock_service.get_card.return_value = mock_card mock_service_cls.return_value = mock_service - result = runner.invoke(card, ["get", "http://localhost:8000", "-a"]) + result = runner.invoke(card, ["get", "http://localhost:8000"]) assert result.exit_code == 0 mock_get_credentials.assert_called_once_with("http://localhost:8000") diff --git a/tests/test_session.py b/tests/test_session.py index 7319f66..35029fa 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -3,6 +3,7 @@ import tempfile from pathlib import Path +from a2a_handler.auth import AuthCredentials, AuthType from a2a_handler.session import AgentSession, SessionStore @@ -183,3 +184,44 @@ def test_load_invalid_json(self): store.load() assert len(store.sessions) == 0 + + def test_save_and_load_mtls_credentials(self): + with tempfile.TemporaryDirectory() as temp_directory: + store = SessionStore(session_directory=Path(temp_directory)) + mtls_creds = AuthCredentials( + auth_type=AuthType.MTLS, + cert_path="/path/to/cert.pem", + key_path="/path/to/key.pem", + ca_cert_path="/path/to/ca.pem", + ) + store.set_credentials("http://localhost:8000", mtls_creds) + + new_store = SessionStore(session_directory=Path(temp_directory)) + new_store.load() + + loaded_creds = new_store.get_credentials("http://localhost:8000") + assert loaded_creds is not None + assert loaded_creds.auth_type == AuthType.MTLS + assert loaded_creds.cert_path == "/path/to/cert.pem" + assert loaded_creds.key_path == "/path/to/key.pem" + assert loaded_creds.ca_cert_path == "/path/to/ca.pem" + + def test_save_and_load_custom_headers(self): + with tempfile.TemporaryDirectory() as temp_directory: + store = SessionStore(session_directory=Path(temp_directory)) + creds = AuthCredentials( + auth_type=AuthType.BEARER, + value="token", + custom_headers={"x-user-id": "me@example.com", "x-org": "acme"}, + ) + store.set_credentials("http://localhost:8000", creds) + + new_store = SessionStore(session_directory=Path(temp_directory)) + new_store.load() + + loaded_creds = new_store.get_credentials("http://localhost:8000") + assert loaded_creds is not None + assert loaded_creds.custom_headers == { + "x-user-id": "me@example.com", + "x-org": "acme", + }