Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 85 additions & 10 deletions src/a2a_handler/auth.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
"""Authentication support for A2A protocol.

Handles credential storage and HTTP authentication header generation.
Currently supports API key and HTTP bearer authentication schemes.
Supports API key, HTTP bearer, and mTLS (mutual TLS) authentication schemes.
"""

import ssl
from dataclasses import dataclass
from enum import Enum
from pathlib import Path


class AuthType(str, Enum):
"""Supported authentication types."""

API_KEY = "api_key"
BEARER = "bearer"
MTLS = "mtls"


@dataclass
Expand All @@ -23,40 +26,91 @@ class AuthCredentials:
"""

auth_type: AuthType
value: str
value: str = ""
header_name: str | None = None # For API key: custom header name
cert_path: str | None = None # For mTLS: client certificate path
key_path: str | None = None # For mTLS: client private key path
ca_cert_path: str | None = None # For mTLS: CA certificate path
custom_headers: dict[str, str] | None = None # Additional headers for any auth type

def to_headers(self) -> dict[str, str]:
"""Generate HTTP headers for this credential.

Returns:
Dictionary of headers to include in requests
"""
if self.auth_type == AuthType.BEARER:
return {"Authorization": f"Bearer {self.value}"}
headers: dict[str, str] = {}
if self.auth_type == AuthType.BEARER and self.value:
headers["Authorization"] = f"Bearer {self.value}"
elif self.auth_type == AuthType.API_KEY:
header = self.header_name or "X-API-Key"
return {header: self.value}
return {}

def to_dict(self) -> dict[str, str | None]:
headers[header] = self.value
if self.custom_headers:
headers.update(self.custom_headers)
return headers

def build_ssl_context(self) -> ssl.SSLContext:
"""Build an SSL context for mTLS client certificate authentication."""
if self.auth_type != AuthType.MTLS:
raise ValueError("SSL context can only be built for mTLS credentials")
if not self.cert_path or not self.key_path:
raise ValueError("cert_path and key_path are required for mTLS")

if self.ca_cert_path:
ctx = ssl.create_default_context(cafile=self.ca_cert_path)
else:
ctx = ssl.create_default_context()

ctx.load_cert_chain(certfile=self.cert_path, keyfile=self.key_path)
return ctx

def to_dict(self) -> dict:
"""Serialize credentials for storage."""
return {
data: dict = {
"auth_type": self.auth_type.value,
"value": self.value,
"header_name": self.header_name,
}
if self.cert_path:
data["cert_path"] = self.cert_path
if self.key_path:
data["key_path"] = self.key_path
if self.ca_cert_path:
data["ca_cert_path"] = self.ca_cert_path
if self.custom_headers:
data["custom_headers"] = self.custom_headers
return data

@classmethod
def from_dict(cls, data: dict[str, str | None]) -> "AuthCredentials":
def from_dict(cls, data: dict) -> "AuthCredentials":
"""Deserialize credentials from storage."""
custom_headers_raw = data.get("custom_headers")
custom_headers = (
dict(custom_headers_raw) if isinstance(custom_headers_raw, dict) else None
)
return cls(
auth_type=AuthType(data["auth_type"]),
value=data.get("value") or "",
header_name=data.get("header_name"),
cert_path=data.get("cert_path"),
key_path=data.get("key_path"),
ca_cert_path=data.get("ca_cert_path"),
custom_headers=custom_headers,
)


def parse_header_string(header: str) -> tuple[str, str]:
"""Parse a 'Name: Value' header string into a (name, value) tuple."""
if ":" not in header:
raise ValueError(f"Invalid header format (expected 'Name: Value'): {header}")
name, _, value = header.partition(":")
name = name.strip()
value = value.strip()
if not name:
raise ValueError(f"Empty header name in: {header}")
return name, value


def create_bearer_auth(token: str) -> AuthCredentials:
"""Create bearer token authentication."""
return AuthCredentials(auth_type=AuthType.BEARER, value=token)
Expand All @@ -77,3 +131,24 @@ def create_api_key_auth(
value=key,
header_name=header_name,
)


def create_mtls_auth(
cert_path: str,
key_path: str,
ca_cert_path: str | None = None,
) -> AuthCredentials:
"""Create mTLS (mutual TLS) client certificate authentication."""
if not Path(cert_path).is_file():
raise FileNotFoundError(f"Client certificate not found: {cert_path}")
if not Path(key_path).is_file():
raise FileNotFoundError(f"Client private key not found: {key_path}")
if ca_cert_path and not Path(ca_cert_path).is_file():
raise FileNotFoundError(f"CA certificate not found: {ca_cert_path}")

return AuthCredentials(
auth_type=AuthType.MTLS,
cert_path=cert_path,
key_path=key_path,
ca_cert_path=ca_cert_path,
)
11 changes: 10 additions & 1 deletion src/a2a_handler/cli/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,24 @@
A2AClientTimeoutError,
)

from a2a_handler.auth import AuthCredentials, AuthType
from a2a_handler.common import Output, get_logger
from a2a_handler.common.input_validation import InputValidationError

TIMEOUT = 120
log = get_logger(__name__)


def build_http_client(timeout: int = TIMEOUT) -> httpx.AsyncClient:
def build_http_client(
timeout: int = TIMEOUT,
credentials: AuthCredentials | None = None,
) -> httpx.AsyncClient:
"""Build an HTTP client with the specified timeout."""
if credentials and credentials.auth_type == AuthType.MTLS:
return httpx.AsyncClient(
timeout=timeout,
verify=credentials.build_ssl_context(),
)
return httpx.AsyncClient(timeout=timeout)


Expand Down
113 changes: 94 additions & 19 deletions src/a2a_handler/cli/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@

import rich_click as click

from a2a_handler.auth import AuthType, create_api_key_auth, create_bearer_auth
from a2a_handler.auth import (
AuthType,
create_api_key_auth,
create_bearer_auth,
create_mtls_auth,
parse_header_string,
)
from a2a_handler.common import Output
from a2a_handler.common.input_validation import (
InputValidationError,
Expand All @@ -31,15 +37,34 @@ def auth() -> None:
default="X-API-Key",
help="Header name for API key (default: X-API-Key)",
)
@click.option("--cert", "cert_path", help="Client certificate path for mTLS (PEM)")
@click.option("--key", "key_path", help="Client private key path for mTLS (PEM)")
@click.option(
"--ca-cert",
"ca_cert_path",
help="CA certificate path for mTLS server verification (PEM)",
)
@click.option(
"--header",
"-H",
"headers",
multiple=True,
help="Custom header (repeatable, format: 'Name: Value')",
)
def auth_set(
agent_url: str,
bearer_token: Optional[str],
api_key: Optional[str],
api_key_header: str,
cert_path: Optional[str],
key_path: Optional[str],
ca_cert_path: Optional[str],
headers: tuple[str, ...],
) -> None:
"""Set authentication credentials for an agent.

Provide either --bearer or --api-key (not both).
Provide --bearer, --api-key, or --cert/--key for mTLS. Custom headers
can be added to any auth method with --header/-H.
"""
output = Output()
try:
Expand All @@ -53,24 +78,63 @@ def auth_set(
handle_validation_error(error, output)
raise click.Abort() from error

if bearer_token and api_key:
output.error("Provide either --bearer or --api-key, not both")
custom_headers: dict[str, str] | None = None
if headers:
custom_headers = {}
for h in headers:
try:
name, value = parse_header_string(h)
reject_control_chars(name, "header name")
reject_control_chars(value, "header value")
custom_headers[name] = value
except (ValueError, InputValidationError) as e:
output.error(str(e))
raise click.Abort() from e

has_mtls = cert_path or key_path
method_count = sum(bool(x) for x in [bearer_token, api_key, has_mtls])

if method_count > 1:
output.error(
"Provide only one auth method: --bearer, --api-key, or --cert/--key"
)
raise click.Abort()

if not bearer_token and not api_key:
output.error("Provide --bearer or --api-key")
if method_count == 0 and not custom_headers:
output.error("Provide --bearer, --api-key, --cert/--key, or --header")
raise click.Abort()

if bearer_token:
if has_mtls:
if not cert_path or not key_path:
output.error("mTLS requires both --cert and --key")
raise click.Abort()
try:
credentials = create_mtls_auth(cert_path, key_path, ca_cert_path)
except FileNotFoundError as e:
output.error(str(e))
raise click.Abort() from e
auth_type_display = "mTLS client certificate"
elif bearer_token:
credentials = create_bearer_auth(bearer_token)
auth_type_display = "Bearer token"
else:
credentials = create_api_key_auth(api_key or "", header_name=api_key_header)
elif api_key:
credentials = create_api_key_auth(api_key, header_name=api_key_header)
auth_type_display = f"API key (header: {api_key_header})"
else:
from a2a_handler.auth import AuthCredentials

credentials = AuthCredentials(auth_type=AuthType.BEARER)
auth_type_display = "Custom headers only"

credentials.custom_headers = custom_headers

set_credentials(agent_url, credentials)

output.success(f"Set {auth_type_display} for {agent_url}")
parts = [auth_type_display]
if custom_headers:
header_names = ", ".join(custom_headers.keys())
parts.append(f"+ headers: {header_names}")
output.success(f"Set {' '.join(parts)} for {agent_url}")


@auth.command("show")
Expand All @@ -93,15 +157,26 @@ def auth_show(agent_url: str) -> None:
return

output.field("Type", credentials.auth_type.value)
masked_value = (
f"{credentials.value[:4]}...{credentials.value[-4:]}"
if len(credentials.value) > 8
else "****"
)
output.field("Value", masked_value)

if credentials.auth_type == AuthType.API_KEY:
output.field("Header", credentials.header_name or "X-API-Key")

if credentials.auth_type == AuthType.MTLS:
output.field("Certificate", credentials.cert_path or "")
output.field("Private Key", credentials.key_path or "")
if credentials.ca_cert_path:
output.field("CA Certificate", credentials.ca_cert_path)
else:
masked_value = (
f"{credentials.value[:4]}...{credentials.value[-4:]}"
if len(credentials.value) > 8
else "****"
)
output.field("Value", masked_value)

if credentials.auth_type == AuthType.API_KEY:
output.field("Header", credentials.header_name or "X-API-Key")

if credentials.custom_headers:
for name, value in credentials.custom_headers.items():
output.field(f"Header: {name}", value)


@auth.command("clear")
Expand Down
15 changes: 6 additions & 9 deletions src/a2a_handler/cli/card.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ def card() -> None:

@card.command("get")
@click.argument("agent_url")
@click.option(
"--authenticated", "-a", is_flag=True, help="Request authenticated extended card"
)
def card_get(agent_url: str, authenticated: bool) -> None:
def card_get(agent_url: str) -> None:
"""Retrieve an agent's card."""
output = Output()
try:
Expand All @@ -46,13 +43,11 @@ def card_get(agent_url: str, authenticated: bool) -> None:
raise click.Abort() from error

log.info("Fetching agent card from %s", agent_url)
credentials = get_credentials(agent_url) if authenticated else None
if authenticated and credentials is None:
log.warning("No saved credentials found for %s", agent_url)
credentials = get_credentials(agent_url)

async def do_get() -> None:
try:
async with build_http_client() as http_client:
async with build_http_client(credentials=credentials) as http_client:
service = A2AService(http_client, agent_url, credentials=credentials)
card_data = await service.get_card()
log.info("Retrieved card for agent: %s", card_data.name)
Expand Down Expand Up @@ -92,9 +87,11 @@ def card_validate(source: str) -> None:
handle_validation_error(error, output)
raise click.Abort() from error

credentials = get_credentials(source) if is_url else None

async def do_validate() -> None:
if is_url:
async with build_http_client() as http_client:
async with build_http_client(credentials=credentials) as http_client:
result = await validate_agent_card_from_url(source, http_client)
else:
result = validate_agent_card_from_file(source)
Expand Down
Loading
Loading