diff --git a/fasta2a/client.py b/fasta2a/client.py index cd84499..e4a2071 100644 --- a/fasta2a/client.py +++ b/fasta2a/client.py @@ -6,6 +6,7 @@ import pydantic from .schema import ( + AgentCard, GetTaskRequest, GetTaskResponse, Message, @@ -31,7 +32,23 @@ class A2AClient: """A client for the A2A protocol.""" - def __init__(self, base_url: str = 'http://localhost:8000', http_client: httpx.AsyncClient | None = None) -> None: + def __init__( + self, + agent: str | AgentCard, + http_client: httpx.AsyncClient | None = None, + fetch_card: bool = False, + relative_card_path: str | None = None, + ) -> None: + self.agent_card = None + if fetch_card and isinstance(agent, str): + if relative_card_path is None: + relative_card_path = '/.well-known/agent-card.json' + agent_url = agent.rstrip('/') + relative_card_path + response = httpx.get(agent_url) + response.raise_for_status() + agent = AgentCard(**response.json()) + self.agent_card = agent + base_url = agent if isinstance(agent, str) else agent['url'] if http_client is None: self.http_client = httpx.AsyncClient(base_url=base_url) else: diff --git a/pyproject.toml b/pyproject.toml index 2af3b7d..0f4b669 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,8 @@ dev = [ "pytest", "ruff", "pyright", + "pytest_asyncio", + "uvicorn" ] docs = [ "mkdocs-material[imaging]", diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..02f4877 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,105 @@ +import asyncio +import socket +import threading +from collections.abc import AsyncGenerator + +import httpx +import pytest +import pytest_asyncio +import uvicorn + +from fasta2a.applications import FastA2A +from fasta2a.broker import InMemoryBroker +from fasta2a.client import A2AClient +from fasta2a.storage import InMemoryStorage + +SERVER_HOST = '127.0.0.1' + + +def get_free_port() -> int: + """Ask OS for a free port.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('', 0)) + return s.getsockname()[1] + + +async def _wait_server(url: str, retries: int = 100, delay: float = 0.1): + """Wait until the server responds (any response, even 404).""" + async with httpx.AsyncClient() as client: + for _ in range(retries): + try: + await client.get(url) + return + except httpx.RequestError: + await asyncio.sleep(delay) + raise RuntimeError(f'Server at {url} did not start in time') + + +def _start_server_in_thread(app: FastA2A, host: str, port: int): + """Run uvicorn server in a background thread.""" + + def _run_uvicorn(): + uvicorn.run(app, host=host, port=port, log_level='error') + + thread = threading.Thread(target=_run_uvicorn, daemon=True) + thread.start() + + +@pytest_asyncio.fixture(scope='function') +async def run_server(request: pytest.FixtureRequest) -> AsyncGenerator[str, None]: + """ + Generic fixture to run a FastA2A server. + + Accepts optional parameters via `request.param`: + - name: agent name + - description: agent description + """ + params = getattr(request, 'param', {}) + port = get_free_port() + url = f'http://{SERVER_HOST}:{port}' + + app = FastA2A( + storage=InMemoryStorage(), + broker=InMemoryBroker(), + url=url, + name=params.get('name'), + description=params.get('description'), + ) + + _start_server_in_thread(app, SERVER_HOST, port) + await _wait_server(url) + yield url + + +# ---------------------- +# Tests +# ---------------------- + + +@pytest.mark.asyncio +async def test_client_basic(run_server: str): + client = A2AClient(agent=run_server) + assert str(client.http_client.base_url) == run_server + + +@pytest.mark.asyncio +async def test_client_fetch_card(run_server: str): + client = A2AClient(agent=run_server, fetch_card=True) + assert hasattr(client, 'agent_card') and client.agent_card is not None + assert client.http_client.base_url == run_server + + +# Parameterize the fixture for tests needing a named agent +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'run_server', + [{'name': 'Test Agent', 'description': 'A test agent for unit tests.'}], + indirect=True, +) +async def test_client_check_agent_card(run_server: str): + client = A2AClient(agent=run_server, fetch_card=True) + assert client.http_client.base_url == run_server + assert hasattr(client, 'agent_card') and client.agent_card is not None + card = getattr(client, 'agent_card', {}) + assert card.get('name') == 'Test Agent' + assert card.get('description') == 'A test agent for unit tests.'