diff --git a/.gitignore b/.gitignore index b3c9e26..dc20eb9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .cache __pycache__ +.coverage* \ No newline at end of file diff --git a/fasta2a/__init__.py b/fasta2a/__init__.py index 4a8b106..b3ccfc0 100644 --- a/fasta2a/__init__.py +++ b/fasta2a/__init__.py @@ -1,7 +1,15 @@ from .applications import FastA2A from .broker import Broker -from .schema import Skill +from .schema import AgentExtension, Skill, StreamEvent from .storage import Storage from .worker import Worker -__all__ = ['FastA2A', 'Skill', 'Storage', 'Broker', 'Worker'] +__all__ = [ + 'AgentExtension', + 'Broker', + 'FastA2A', + 'Skill', + 'Storage', + 'StreamEvent', + 'Worker', +] diff --git a/fasta2a/applications.py b/fasta2a/applications.py index 958530d..93046e4 100644 --- a/fasta2a/applications.py +++ b/fasta2a/applications.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Any +from sse_starlette import EventSourceResponse from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.requests import Request @@ -16,11 +17,15 @@ from .schema import ( AgentCapabilities, AgentCard, + AgentExtension, AgentProvider, Skill, + StreamMessageResponse, a2a_request_ta, a2a_response_ta, agent_card_ta, + stream_message_request_ta, + stream_message_response_ta, ) from .storage import Storage from .task_manager import TaskManager @@ -41,7 +46,9 @@ def __init__( description: str | None = None, provider: AgentProvider | None = None, skills: list[Skill] | None = None, + extensions: list[AgentExtension] | None = None, docs_url: str | None = '/docs', + streaming: bool = True, # Starlette debug: bool = False, routes: Sequence[Route] | None = None, @@ -66,7 +73,9 @@ def __init__( self.description = description self.provider = provider self.skills = skills or [] + self.extensions = extensions or [] self.docs_url = docs_url + self.streaming = streaming # NOTE: For now, I don't think there's any reason to support any other input/output modes. self.default_input_modes = ['application/json'] self.default_output_modes = ['application/json'] @@ -90,6 +99,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def _agent_card_endpoint(self, request: Request) -> Response: if self._agent_card_json_schema is None: + capabilities = AgentCapabilities( + streaming=self.streaming, push_notifications=False, state_transition_history=False + ) + if self.extensions: + capabilities['extensions'] = self.extensions agent_card = AgentCard( name=self.name, description=self.description or 'An AI agent exposed as an A2A agent.', @@ -99,9 +113,7 @@ async def _agent_card_endpoint(self, request: Request) -> Response: skills=self.skills, default_input_modes=self.default_input_modes, default_output_modes=self.default_output_modes, - capabilities=AgentCapabilities( - streaming=False, push_notifications=False, state_transition_history=False - ), + capabilities=capabilities, ) if self.provider is not None: agent_card['provider'] = self.provider @@ -129,8 +141,30 @@ async def _agent_run_endpoint(self, request: Request) -> Response: data = await request.body() a2a_request = a2a_request_ta.validate_json(data) + # Parse activated extensions from the A2A-Extensions header + extensions_header = request.headers.get('a2a-extensions', '') + activated_extensions: list[str] = ( + [uri.strip() for uri in extensions_header.split(',') if uri.strip()] if extensions_header else [] + ) + # Stash on the request state so workers / handlers can inspect them + request.state.activated_extensions = activated_extensions + if a2a_request['method'] == 'message/send': jsonrpc_response = await self.task_manager.send_message(a2a_request) + elif a2a_request['method'] == 'message/stream': + stream_request = stream_message_request_ta.validate_json(data) + + async def sse_generator(): + request_id = stream_request.get('id') + async for event in self.task_manager.stream_message(stream_request): + jsonrpc_response = StreamMessageResponse( + jsonrpc='2.0', + id=request_id, + result=event, + ) + yield stream_message_response_ta.dump_json(jsonrpc_response, by_alias=True).decode() + + return EventSourceResponse(sse_generator()) elif a2a_request['method'] == 'tasks/get': jsonrpc_response = await self.task_manager.get_task(a2a_request) elif a2a_request['method'] == 'tasks/cancel': diff --git a/fasta2a/broker.py b/fasta2a/broker.py index c84b738..650c9e2 100644 --- a/fasta2a/broker.py +++ b/fasta2a/broker.py @@ -7,11 +7,12 @@ from typing import Annotated, Any, Generic, Literal, TypeVar import anyio +from anyio.streams.memory import MemoryObjectSendStream from opentelemetry.trace import Span, get_current_span, get_tracer from pydantic import Discriminator from typing_extensions import Self, TypedDict -from .schema import TaskIdParams, TaskSendParams +from .schema import StreamEvent, TaskIdParams, TaskSendParams tracer = get_tracer(__name__) @@ -37,6 +38,26 @@ async def cancel_task(self, params: TaskIdParams) -> None: """Cancel a task.""" raise NotImplementedError('send_cancel_task is not implemented yet.') + @abstractmethod + async def send_stream_event(self, task_id: str, event: StreamEvent) -> None: + """Send a streaming event from worker to subscribers. + + This is used by workers to publish status updates, messages, and artifacts + during task execution. Events are forwarded to all active subscribers of + the given task_id. + """ + ... + + @abstractmethod + def subscribe_to_stream(self, task_id: str) -> AsyncIterator[StreamEvent]: + """Subscribe to streaming events for a specific task. + + Returns an async iterator that yields events published by workers for the + given task_id. The iterator completes when a TaskStatusUpdateEvent with + final=True is received or the subscription is cancelled. + """ + ... + @abstractmethod async def __aenter__(self) -> Self: ... @@ -73,6 +94,10 @@ class _TaskOperation(TypedDict, Generic[OperationT, ParamsT]): class InMemoryBroker(Broker): """A broker that schedules tasks in memory.""" + def __init__(self) -> None: + self._event_subscribers: dict[str, list[MemoryObjectSendStream[StreamEvent]]] = {} + self._subscriber_lock: anyio.Lock | None = None + async def __aenter__(self): self.aexit_stack = AsyncExitStack() await self.aexit_stack.__aenter__() @@ -81,6 +106,8 @@ async def __aenter__(self): await self.aexit_stack.enter_async_context(self._read_stream) await self.aexit_stack.enter_async_context(self._write_stream) + self._subscriber_lock = anyio.Lock() + return self async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): @@ -96,3 +123,65 @@ async def receive_task_operations(self) -> AsyncIterator[TaskOperation]: """Receive task operations from the broker.""" async for task_operation in self._read_stream: yield task_operation + + async def send_stream_event(self, task_id: str, event: StreamEvent) -> None: + """Send a streaming event from worker to subscribers.""" + assert self._subscriber_lock is not None, 'Broker not initialized' + + async with self._subscriber_lock: + subscribers = self._event_subscribers.get(task_id, []) + if not subscribers: + return + + # Send event to all subscribers, removing closed streams + active_subscribers: list[MemoryObjectSendStream[StreamEvent]] = [] + for stream in subscribers: + try: + await stream.send(event) + active_subscribers.append(stream) + except (anyio.ClosedResourceError, anyio.BrokenResourceError): + # Subscriber disconnected, remove from list + pass + + # Update subscriber list with only active ones + if active_subscribers: + self._event_subscribers[task_id] = active_subscribers + elif task_id in self._event_subscribers: + # No active subscribers left, clean up + del self._event_subscribers[task_id] + + async def subscribe_to_stream(self, task_id: str) -> AsyncIterator[StreamEvent]: + """Subscribe to streaming events for a specific task.""" + assert self._subscriber_lock is not None, 'Broker not initialized' + + # Create a new stream for this subscriber + send_stream, receive_stream = anyio.create_memory_object_stream[StreamEvent](max_buffer_size=100) + + # Register the subscriber + async with self._subscriber_lock: + if task_id not in self._event_subscribers: + self._event_subscribers[task_id] = [] + self._event_subscribers[task_id].append(send_stream) + + try: + async with receive_stream: + async for event in receive_stream: + yield event + + # Check if this is a final status update + if isinstance(event, dict) and event.get('kind') == 'status-update' and event.get('final', False): + break + finally: + # Clean up subscription on exit + async with self._subscriber_lock: + if task_id in self._event_subscribers: + try: + self._event_subscribers[task_id].remove(send_stream) + if not self._event_subscribers[task_id]: + del self._event_subscribers[task_id] + except ValueError: + # Already removed + pass + + # Close the send stream + await send_stream.aclose() diff --git a/fasta2a/schema.py b/fasta2a/schema.py index 37ffb86..4f806d4 100644 --- a/fasta2a/schema.py +++ b/fasta2a/schema.py @@ -93,6 +93,15 @@ class AgentCapabilities(TypedDict): state_transition_history: NotRequired[bool] """Whether the agent exposes status change history for tasks.""" + extensions: NotRequired[list[AgentExtension]] + """A2A extensions supported by this agent. + + Each extension is declared as an ``AgentExtension`` object with a + unique ``uri``, optional ``description``, ``required`` flag, and + ``params`` configuration. Clients activate extensions by sending + the selected URIs in the ``A2A-Extensions`` HTTP header. + """ + @pydantic.with_config({'alias_generator': to_camel}) class HttpSecurityScheme(TypedDict): @@ -808,3 +817,9 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): send_message_response_ta: TypeAdapter[SendMessageResponse] = TypeAdapter(SendMessageResponse) stream_message_request_ta: TypeAdapter[StreamMessageRequest] = TypeAdapter(StreamMessageRequest) stream_message_response_ta: TypeAdapter[StreamMessageResponse] = TypeAdapter(StreamMessageResponse) + +# Type for streaming events (used by broker and task manager) +StreamEvent = Union[Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent] +"""A streaming event that can be sent during message/stream requests.""" + +stream_event_ta: TypeAdapter[StreamEvent] = TypeAdapter(StreamEvent) diff --git a/fasta2a/storage.py b/fasta2a/storage.py index 58934c1..cc94d47 100644 --- a/fasta2a/storage.py +++ b/fasta2a/storage.py @@ -9,7 +9,13 @@ from typing_extensions import TypeVar -from .schema import Artifact, Message, Task, TaskState, TaskStatus +from .schema import ( + Artifact, + Message, + Task, + TaskState, + TaskStatus, +) ContextT = TypeVar('ContextT', default=Any) diff --git a/fasta2a/task_manager.py b/fasta2a/task_manager.py index 54ab709..9457ce8 100644 --- a/fasta2a/task_manager.py +++ b/fasta2a/task_manager.py @@ -60,7 +60,9 @@ from __future__ import annotations as _annotations +import asyncio import uuid +from collections.abc import AsyncGenerator from contextlib import AsyncExitStack from dataclasses import dataclass, field from typing import Any @@ -78,8 +80,8 @@ SendMessageResponse, SetTaskPushNotificationRequest, SetTaskPushNotificationResponse, + StreamEvent, StreamMessageRequest, - StreamMessageResponse, TaskNotFoundError, TaskSendParams, ) @@ -156,9 +158,44 @@ async def cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: ) return CancelTaskResponse(jsonrpc='2.0', id=request['id'], result=task) - async def stream_message(self, request: StreamMessageRequest) -> StreamMessageResponse: - """Stream messages using Server-Sent Events.""" - raise NotImplementedError('message/stream method is not implemented yet.') + async def stream_message(self, request: StreamMessageRequest) -> AsyncGenerator[StreamEvent, None]: + """Handle a streaming message request. + + This method: + 1. Creates and submits a new task + 2. Yields the initial task object + 3. Subscribes to the broker's event stream + 4. Starts task execution asynchronously + 5. Streams all events until completion + """ + # Extract parameters + params = request['params'] + message = params['message'] + context_id = message.get('context_id', str(uuid.uuid4())) + + # Create and submit the task + task = await self.storage.submit_task(context_id, message) + + # Yield the initial task + yield task + + # Prepare broker params + broker_params: TaskSendParams = {'id': task['id'], 'context_id': context_id, 'message': message} + config = params.get('configuration', {}) + history_length = config.get('history_length') + if history_length is not None: + broker_params['history_length'] = history_length + + metadata = params.get('metadata') + if metadata is not None: + broker_params['metadata'] = metadata + + # Start task execution in background + asyncio.create_task(self.broker.run_task(broker_params)) + + # Stream events from broker + async for event in self.broker.subscribe_to_stream(task['id']): + yield event async def set_task_push_notification( self, request: SetTaskPushNotificationRequest diff --git a/fasta2a/worker.py b/fasta2a/worker.py index bcb0172..3499c45 100644 --- a/fasta2a/worker.py +++ b/fasta2a/worker.py @@ -10,11 +10,12 @@ from opentelemetry.trace import get_tracer, use_span from typing_extensions import assert_never +from .schema import TaskArtifactUpdateEvent, TaskStatusUpdateEvent from .storage import ContextT, Storage if TYPE_CHECKING: from .broker import Broker, TaskOperation - from .schema import Artifact, Message, TaskIdParams, TaskSendParams + from .schema import Artifact, Message, TaskIdParams, TaskSendParams, TaskState tracer = get_tracer(__name__) @@ -56,6 +57,66 @@ async def _handle_task_operation(self, task_operation: TaskOperation) -> None: except Exception: await self.storage.update_task(task_operation['params']['id'], state='failed') + async def update_task( + self, + task_id: str, + state: TaskState, + new_artifacts: list[Artifact] | None = None, + new_messages: list[Message] | None = None, + ) -> None: + """Update a task's state in storage and publish streaming events to the broker. + + This is the primary method workers should use to update task state. It handles + both persisting the update and notifying any stream subscribers. + """ + task = await self.storage.update_task(task_id, state, new_artifacts, new_messages) + + final = state in ('completed', 'failed', 'canceled') + + # For non-final updates, publish status first + if not final: + await self.broker.send_stream_event( + task_id, + TaskStatusUpdateEvent( + kind='status-update', + task_id=task_id, + context_id=task['context_id'], + status=task['status'], + final=False, + ), + ) + + # Publish message events before final status so subscribers receive them + if new_messages: + for message in new_messages: + await self.broker.send_stream_event(task_id, message) + + # Publish artifact events + if new_artifacts: + for artifact in new_artifacts: + await self.broker.send_stream_event( + task_id, + TaskArtifactUpdateEvent( + kind='artifact-update', + task_id=task_id, + context_id=task['context_id'], + artifact=artifact, + ), + ) + + # For final updates, publish status last (after messages and artifacts) + if final: + await self.broker.send_stream_event( + task_id, + TaskStatusUpdateEvent( + kind='status-update', + task_id=task_id, + context_id=task['context_id'], + status=task['status'], + final=True, + ), + ) + @abstractmethod async def run_task(self, params: TaskSendParams) -> None: ... diff --git a/pyproject.toml b/pyproject.toml index 2af3b7d..ee2a70b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ "pydantic>=2.10", "opentelemetry-api>=1.28.0", "eval_type_backport>=0.2.2; python_version <= '3.9'", + "sse-starlette>=2.0.0", ] [project.optional-dependencies] @@ -58,8 +59,11 @@ dev = [ "asgi-lifespan", "coverage", "httpx", + "httpx-sse", "inline-snapshot", "pytest", + "pytest-asyncio", + "pytest-mock", "ruff", "pyright", ] diff --git a/tests/test_applications.py b/tests/test_applications.py index 21692e0..1bc2798 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -38,7 +38,7 @@ async def test_agent_card(): 'defaultInputModes': ['application/json'], 'defaultOutputModes': ['application/json'], 'capabilities': { - 'streaming': False, + 'streaming': True, 'pushNotifications': False, 'stateTransitionHistory': False, }, diff --git a/tests/test_streaming.py b/tests/test_streaming.py new file mode 100644 index 0000000..baf5eea --- /dev/null +++ b/tests/test_streaming.py @@ -0,0 +1,585 @@ +"""Tests for the SSE streaming feature in fasta2a. + +This module tests: +- Worker.update_task() streaming event publishing +- InMemoryBroker pub/sub for streaming +- TaskManager stream_message method +- FastA2A message/stream endpoint +""" + +from __future__ import annotations as _annotations + +import asyncio +import json +import uuid +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import Any + +import httpx +import pytest +from asgi_lifespan import LifespanManager + +from fasta2a import FastA2A, Worker +from fasta2a.broker import InMemoryBroker +from fasta2a.schema import ( + Artifact, + Message, + MessageSendParams, + StreamEvent, + StreamMessageRequest, + TaskIdParams, + TaskSendParams, + TaskStatus, + TaskStatusUpdateEvent, + TextPart, +) +from fasta2a.storage import InMemoryStorage + +pytestmark = pytest.mark.anyio + + +# Test fixtures and helpers + + +@asynccontextmanager +async def create_test_client(app: FastA2A): + """Create a test client for the FastA2A app.""" + async with LifespanManager(app=app) as manager: + transport = httpx.ASGITransport(app=manager.app) + async with httpx.AsyncClient(transport=transport, base_url='http://testclient') as client: + yield client + + +Context = list[Message] + + +@dataclass +class EchoWorker(Worker[Context]): + """A simple worker for testing that echoes messages.""" + + response_text: str = 'Hello from test worker!' + delay: float = 0.1 + + async def run_task(self, params: TaskSendParams) -> None: + task = await self.storage.load_task(params['id']) + assert task is not None + + # Update to working state + await self.update_task(task['id'], state='working') + + # Simulate some work + await asyncio.sleep(self.delay) + + # Create response message + message = Message( + role='agent', + parts=[TextPart(text=self.response_text, kind='text')], + kind='message', + message_id=str(uuid.uuid4()), + ) + + # Create artifact + artifact = Artifact( + artifact_id=str(uuid.uuid4()), + name='result', + parts=[TextPart(text=self.response_text, kind='text')], + ) + + # Complete the task with message and artifact + await self.update_task( + task['id'], + state='completed', + new_messages=[message], + new_artifacts=[artifact], + ) + + async def cancel_task(self, params: TaskIdParams) -> None: + await self.update_task(params['id'], state='canceled') + + def build_message_history(self, history: list[Message]) -> list[Any]: + return history + + def build_artifacts(self, result: Any) -> list[Artifact]: + return [] + + +@asynccontextmanager +async def create_streaming_app(response_text: str = 'Hello!'): + """Create a FastA2A app with streaming enabled.""" + broker = InMemoryBroker() + storage = InMemoryStorage() + + worker = EchoWorker( + broker=broker, + storage=storage, + response_text=response_text, + ) + + @asynccontextmanager + async def lifespan(app: FastA2A): + async with app.task_manager: + async with worker.run(): + yield + + app = FastA2A( + storage=storage, + broker=broker, + streaming=True, + lifespan=lifespan, + ) + + yield app + + +# Tests for InMemoryBroker streaming + + +async def test_broker_subscribe_and_receive_events(): + """Test that subscribers receive events sent by send_stream_event.""" + broker = InMemoryBroker() + + async with broker: + task_id = 'test-task-123' + received_events: list[StreamEvent] = [] + + # Start subscriber in background + async def subscriber(): + async for event in broker.subscribe_to_stream(task_id): + received_events.append(event) + # Check if event has a 'final' field (for status events) + if isinstance(event, dict) and event.get('final', False): + break + + subscriber_task = asyncio.create_task(subscriber()) + + # Give subscriber time to register + await asyncio.sleep(0.01) + + # Send events + status_event_1 = TaskStatusUpdateEvent( + task_id=task_id, + context_id='test-context', + kind='status-update', + status=TaskStatus(state='working'), + final=False, + ) + message_event = Message( + message_id='test-msg-1', + role='agent', + parts=[TextPart(kind='text', text='Hello')], + kind='message', + ) + status_event_2 = TaskStatusUpdateEvent( + task_id=task_id, + context_id='test-context', + kind='status-update', + status=TaskStatus(state='completed'), + final=True, + ) + await broker.send_stream_event(task_id, status_event_1) + await broker.send_stream_event(task_id, message_event) + await broker.send_stream_event(task_id, status_event_2) + + # Wait for subscriber to finish + await asyncio.wait_for(subscriber_task, timeout=1.0) + + assert len(received_events) == 3 + assert received_events[0]['kind'] == 'status-update' + assert received_events[1]['kind'] == 'message' # Message has both kind and role + # Only check 'final' on status update events + assert received_events[2]['kind'] == 'status-update' and received_events[2]['final'] is True + + +async def test_broker_multiple_subscribers(): + """Test that multiple subscribers receive the same events.""" + broker = InMemoryBroker() + + async with broker: + task_id = 'test-task-456' + received_1: list[StreamEvent] = [] + received_2: list[StreamEvent] = [] + + async def subscriber1(): + async for event in broker.subscribe_to_stream(task_id): + received_1.append(event) + if isinstance(event, dict) and event.get('final', False): + break + + async def subscriber2(): + async for event in broker.subscribe_to_stream(task_id): + received_2.append(event) + if isinstance(event, dict) and event.get('final', False): + break + + task1 = asyncio.create_task(subscriber1()) + task2 = asyncio.create_task(subscriber2()) + + await asyncio.sleep(0.01) + + await broker.send_stream_event( + task_id, + TaskStatusUpdateEvent( + task_id=task_id, + context_id='test-context', + kind='status-update', + status=TaskStatus(state='working'), + final=False, + ), + ) + await broker.send_stream_event( + task_id, + TaskStatusUpdateEvent( + task_id=task_id, + context_id='test-context', + kind='status-update', + status=TaskStatus(state='completed'), + final=True, + ), + ) + + await asyncio.wait_for(asyncio.gather(task1, task2), timeout=1.0) + + assert len(received_1) == 2 + assert len(received_2) == 2 + + +async def test_broker_no_subscribers_doesnt_error(): + """Test that sending events with no subscribers doesn't raise errors.""" + broker = InMemoryBroker() + + async with broker: + # Should not raise + await broker.send_stream_event( + 'nonexistent-task', + TaskStatusUpdateEvent( + task_id='nonexistent-task', + context_id='test-context', + kind='status-update', + status=TaskStatus(state='working'), + final=False, + ), + ) + + +# Tests for Worker.update_task() streaming events + + +async def test_worker_publishes_status_update_on_working(): + """Test that updating to 'working' publishes a status-update event.""" + broker = InMemoryBroker() + storage = InMemoryStorage() + worker = EchoWorker(broker=broker, storage=storage) + + async with broker: + # Create a task + message = Message( + role='user', + parts=[TextPart(text='Hello', kind='text')], + kind='message', + message_id='msg-1', + ) + task = await storage.submit_task('ctx-1', message) + task_id = task['id'] + + received_events: list[StreamEvent] = [] + + async def subscriber(): + async for event in broker.subscribe_to_stream(task_id): + received_events.append(event) + # Stop after first event for this test + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.01) + + # Update to working via worker + await worker.update_task(task_id, state='working') + + await asyncio.wait_for(sub_task, timeout=1.0) + + assert len(received_events) == 1 + assert received_events[0]['kind'] == 'status-update' + assert received_events[0]['status']['state'] == 'working' + assert received_events[0]['final'] is False + + +async def test_worker_publishes_message_before_final_status(): + """Test that messages are published before the final status update.""" + broker = InMemoryBroker() + storage = InMemoryStorage() + worker = EchoWorker(broker=broker, storage=storage) + + async with broker: + message = Message( + role='user', + parts=[TextPart(text='Hello', kind='text')], + kind='message', + message_id='msg-1', + ) + task = await storage.submit_task('ctx-1', message) + task_id = task['id'] + + received_events: list[StreamEvent] = [] + + async def subscriber(): + async for event in broker.subscribe_to_stream(task_id): + received_events.append(event) + if isinstance(event, dict) and event.get('final', False): + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.01) + + # Complete with a new message via worker + agent_message = Message( + role='agent', + parts=[TextPart(text='Hi there!', kind='text')], + kind='message', + message_id='msg-2', + ) + await worker.update_task(task_id, state='completed', new_messages=[agent_message]) + + await asyncio.wait_for(sub_task, timeout=1.0) + + # Should have: message, then final status + assert len(received_events) == 2 + assert received_events[0]['kind'] == 'message' + assert received_events[0]['role'] == 'agent' + assert received_events[1]['kind'] == 'status-update' + assert received_events[1]['final'] is True + + +async def test_worker_publishes_artifacts(): + """Test that artifacts are published.""" + broker = InMemoryBroker() + storage = InMemoryStorage() + worker = EchoWorker(broker=broker, storage=storage) + + async with broker: + message = Message( + role='user', + parts=[TextPart(text='Hello', kind='text')], + kind='message', + message_id='msg-1', + ) + task = await storage.submit_task('ctx-1', message) + task_id = task['id'] + + received_events: list[StreamEvent] = [] + + async def subscriber(): + async for event in broker.subscribe_to_stream(task_id): + received_events.append(event) + if isinstance(event, dict) and event.get('final', False): + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.01) + + artifact = Artifact( + artifact_id='art-1', + name='result', + parts=[TextPart(text='Result data', kind='text')], + ) + await worker.update_task(task_id, state='completed', new_artifacts=[artifact]) + + await asyncio.wait_for(sub_task, timeout=1.0) + + # Should have: artifact, then final status + assert len(received_events) == 2 + assert received_events[0]['kind'] == 'artifact-update' + artifact = received_events[0]['artifact'] + assert artifact.get('name') == 'result' # Use .get() for NotRequired fields + assert received_events[1]['kind'] == 'status-update' + assert received_events[1]['final'] is True + + +# Tests for FastA2A streaming endpoint + + +async def test_agent_card_shows_streaming_enabled(): + """Test that agent card reflects streaming capability.""" + async with create_streaming_app() as app: + async with create_test_client(app) as client: + response = await client.get('/.well-known/agent-card.json') + assert response.status_code == 200 + data = response.json() + assert data['capabilities']['streaming'] is True + + +async def test_message_stream_returns_sse(): + """Test that message/stream returns SSE response.""" + async with create_streaming_app(response_text='Test response') as app: + async with create_test_client(app) as client: + payload = { + 'jsonrpc': '2.0', + 'id': 'test-request-1', + 'method': 'message/stream', + 'params': { + 'message': { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello'}], + 'kind': 'message', + 'messageId': 'user-msg-1', + } + }, + } + + # Use streaming request + async with client.stream('POST', '/', json=payload) as response: + assert response.status_code == 200 + assert 'text/event-stream' in response.headers.get('content-type', '') + + events: list[Any] = [] + async for line in response.aiter_lines(): + if line.startswith('data: '): + data = json.loads(line[6:]) + events.append(data) + + # Should have multiple events + assert len(events) >= 3 # task, working status, completed status + + # First event should be the task + assert events[0]['result']['kind'] == 'task' + assert events[0]['result']['status']['state'] == 'submitted' + + # Should have a working status + working_events = [ + e + for e in events + if e['result'].get('kind') == 'status-update' + and e['result'].get('status', {}).get('state') == 'working' + ] + assert len(working_events) >= 1 + + # Should have a message with the response + message_events = [ + e for e in events if e['result'].get('kind') == 'message' and e['result'].get('role') == 'agent' + ] + assert len(message_events) >= 1 + assert 'Test response' in message_events[0]['result']['parts'][0]['text'] + + # Last event should be final status + final_events = [e for e in events if e['result'].get('final') is True] + assert len(final_events) >= 1 + + +async def test_message_stream_includes_context_id(): + """Test that streamed events include context_id.""" + async with create_streaming_app() as app: + async with create_test_client(app) as client: + context_id = str(uuid.uuid4()) + payload = { + 'jsonrpc': '2.0', + 'id': 'test-2', + 'method': 'message/stream', + 'params': { + 'message': { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hi'}], + 'kind': 'message', + 'messageId': 'msg-1', + 'contextId': context_id, + } + }, + } + + async with client.stream('POST', '/', json=payload) as response: + events: list[Any] = [] + async for line in response.aiter_lines(): + if line.startswith('data: '): + events.append(json.loads(line[6:])) + + # All events should have the same context_id + for event in events: + result = event['result'] + event_context = result.get('contextId') or result.get('context_id') + assert event_context == context_id + + +async def test_message_send_still_works(): + """Test that non-streaming message/send still works.""" + async with create_streaming_app() as app: + async with create_test_client(app) as client: + payload = { + 'jsonrpc': '2.0', + 'id': 'test-3', + 'method': 'message/send', + 'params': { + 'message': { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello'}], + 'kind': 'message', + 'messageId': 'msg-1', + } + }, + } + + response = await client.post('/', json=payload) + assert response.status_code == 200 + data = response.json() + assert data['result']['kind'] == 'task' + assert data['result']['status']['state'] == 'submitted' + + +# Tests for TaskManager stream_message + + +async def test_stream_message_yields_events_in_order(): + """Test that stream_message yields events: task, status updates, messages, final status.""" + broker = InMemoryBroker() + storage = InMemoryStorage() + + # Use a longer delay to ensure we capture the initial task before it completes + worker = EchoWorker(broker=broker, storage=storage, delay=0.2) + + async with broker: + async with worker.run(): + from fasta2a.task_manager import TaskManager + + task_manager = TaskManager(broker=broker, storage=storage) + async with task_manager: + request: StreamMessageRequest = { + 'jsonrpc': '2.0', + 'id': 'req-1', + 'method': 'message/stream', + 'params': MessageSendParams( + message=Message( + role='user', + parts=[TextPart(kind='text', text='Test')], + kind='message', + message_id='msg-1', + ) + ), + } + + events: list[StreamEvent] = [] + async for event in task_manager.stream_message(request): + events.append(event) + # Stop when we get the final status + if isinstance(event, dict) and event.get('kind') == 'status-update' and event.get('final'): + break + + # Should have at least: task, working status, message, final status + assert len(events) >= 4 + + # First event should be the task (submitted state) + assert events[0]['kind'] == 'task' + + # Should have a working status update + working_events = [ + e + for e in events + if e.get('kind') == 'status-update' and e.get('status', {}).get('state') == 'working' + ] + assert len(working_events) >= 1 + + # Should have an agent message + message_events = [e for e in events if e.get('kind') == 'message' and e.get('role') == 'agent'] + assert len(message_events) >= 1 + + # Last event should be final status + assert events[-1]['kind'] == 'status-update' + assert events[-1]['final'] is True + assert events[-1]['status']['state'] == 'completed' diff --git a/uv.lock b/uv.lock index a77fc88..5015736 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.9" resolution-markers = [ "python_full_version >= '3.10'", @@ -60,6 +60,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537, upload-time = "2025-02-01T15:17:37.39Z" }, ] +[[package]] +name = "backports-asyncio-runner" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893, upload-time = "2025-07-02T02:27:15.685Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" }, +] + [[package]] name = "backrefs" version = "5.9" @@ -439,7 +448,10 @@ dependencies = [ { name = "eval-type-backport", marker = "python_full_version < '3.10'" }, { name = "opentelemetry-api" }, { name = "pydantic" }, - { name = "starlette" }, + { name = "sse-starlette", version = "3.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "sse-starlette", version = "3.3.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "starlette", version = "0.49.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "starlette", version = "0.52.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, ] [package.optional-dependencies] @@ -452,9 +464,13 @@ dev = [ { name = "asgi-lifespan" }, { name = "coverage" }, { name = "httpx" }, + { name = "httpx-sse" }, { name = "inline-snapshot" }, { name = "pyright" }, { name = "pytest" }, + { name = "pytest-asyncio", version = "1.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "pytest-asyncio", version = "1.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "pytest-mock" }, { name = "ruff" }, ] docs = [ @@ -470,6 +486,7 @@ requires-dist = [ { name = "logfire", marker = "extra == 'logfire'", specifier = ">=2.3" }, { name = "opentelemetry-api", specifier = ">=1.28.0" }, { name = "pydantic", specifier = ">=2.10" }, + { name = "sse-starlette", specifier = ">=2.0.0" }, { name = "starlette", specifier = ">0.29.0" }, ] provides-extras = ["logfire"] @@ -479,9 +496,12 @@ dev = [ { name = "asgi-lifespan" }, { name = "coverage" }, { name = "httpx" }, + { name = "httpx-sse" }, { name = "inline-snapshot" }, { name = "pyright" }, { name = "pytest" }, + { name = "pytest-asyncio" }, + { name = "pytest-mock" }, { name = "ruff" }, ] docs = [ @@ -564,6 +584,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, ] +[[package]] +name = "httpx-sse" +version = "0.4.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/4c/751061ffa58615a32c31b2d82e8482be8dd4a89154f003147acee90f2be9/httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d", size = 15943, upload-time = "2025-10-10T21:48:22.271Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/fd/6668e5aec43ab844de6fc74927e155a3b37bf40d7c3790e49fc0406b6578/httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc", size = 8960, upload-time = "2025-10-10T21:48:21.158Z" }, +] + [[package]] name = "idna" version = "3.10" @@ -1356,6 +1385,52 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.10'", +] +dependencies = [ + { name = "backports-asyncio-runner", marker = "python_full_version < '3.10'" }, + { name = "pytest", marker = "python_full_version < '3.10'" }, + { name = "typing-extensions", marker = "python_full_version < '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/86/9e3c5f48f7b7b638b216e4b9e645f54d199d7abbbab7a64a13b4e12ba10f/pytest_asyncio-1.2.0.tar.gz", hash = "sha256:c609a64a2a8768462d0c99811ddb8bd2583c33fd33cf7f21af1c142e824ffb57", size = 50119, upload-time = "2025-09-12T07:33:53.816Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/93/2fa34714b7a4ae72f2f8dad66ba17dd9a2c793220719e736dda28b7aec27/pytest_asyncio-1.2.0-py3-none-any.whl", hash = "sha256:8e17ae5e46d8e7efe51ab6494dd2010f4ca8dae51652aa3c8d55acf50bfb2e99", size = 15095, upload-time = "2025-09-12T07:33:52.639Z" }, +] + +[[package]] +name = "pytest-asyncio" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.10'", +] +dependencies = [ + { name = "backports-asyncio-runner", marker = "python_full_version == '3.10.*'" }, + { name = "pytest", marker = "python_full_version >= '3.10'" }, + { name = "typing-extensions", marker = "python_full_version >= '3.10' and python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, +] + +[[package]] +name = "pytest-mock" +version = "3.15.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/68/14/eb014d26be205d38ad5ad20d9a80f7d201472e08167f0bb4361e251084a9/pytest_mock-3.15.1.tar.gz", hash = "sha256:1849a238f6f396da19762269de72cb1814ab44416fa73a8686deac10b0d87a0f", size = 34036, upload-time = "2025-09-16T16:37:27.081Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -1514,17 +1589,68 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/9c/0e6afc12c269578be5c0c1c9f4b49a8d32770a080260c333ac04cc1c832d/soupsieve-2.7-py3-none-any.whl", hash = "sha256:6e60cc5c1ffaf1cebcc12e8188320b72071e922c2e897f737cadce79ad5d30c4", size = 36677, upload-time = "2025-04-20T18:50:07.196Z" }, ] +[[package]] +name = "sse-starlette" +version = "3.3.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.10'", +] +dependencies = [ + { name = "anyio", marker = "python_full_version < '3.10'" }, + { name = "starlette", version = "0.49.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/5a/a1b3e05da1de8dd17fcbe2d6e66725bff6a8d96293db53cd0d89eac84f81/sse_starlette-3.3.0.tar.gz", hash = "sha256:fdf4a84e2230b12daa3a5a4a1a651586debdefd6eb2fbf812554490d01326896", size = 32679, upload-time = "2026-02-28T08:30:33.178Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b6/7a/e3ff499eaee52cf4c78bbf483a3ec536b0b04290e482d5a740cd2729fdcb/sse_starlette-3.3.0-py3-none-any.whl", hash = "sha256:eb5acdac069c7c8b2ce2d3c447b58016da9737ff8a8f475438d88397d49883ef", size = 14276, upload-time = "2026-02-28T08:30:31.771Z" }, +] + +[[package]] +name = "sse-starlette" +version = "3.3.2" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.10'", +] +dependencies = [ + { name = "anyio", marker = "python_full_version >= '3.10'" }, + { name = "starlette", version = "0.52.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5a/9f/c3695c2d2d4ef70072c3a06992850498b01c6bc9be531950813716b426fa/sse_starlette-3.3.2.tar.gz", hash = "sha256:678fca55a1945c734d8472a6cad186a55ab02840b4f6786f5ee8770970579dcd", size = 32326, upload-time = "2026-02-28T11:24:34.36Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/28/8cb142d3fe80c4a2d8af54ca0b003f47ce0ba920974e7990fa6e016402d1/sse_starlette-3.3.2-py3-none-any.whl", hash = "sha256:5c3ea3dad425c601236726af2f27689b74494643f57017cafcb6f8c9acfbb862", size = 14270, upload-time = "2026-02-28T11:24:32.984Z" }, +] + [[package]] name = "starlette" -version = "0.47.1" +version = "0.49.3" source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.10'", +] dependencies = [ - { name = "anyio" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "anyio", marker = "python_full_version < '3.10'" }, + { name = "typing-extensions", marker = "python_full_version < '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/de/1a/608df0b10b53b0beb96a37854ee05864d182ddd4b1156a22f1ad3860425a/starlette-0.49.3.tar.gz", hash = "sha256:1c14546f299b5901a1ea0e34410575bc33bbd741377a10484a54445588d00284", size = 2655031, upload-time = "2025-11-01T15:12:26.13Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a3/e0/021c772d6a662f43b63044ab481dc6ac7592447605b5b35a957785363122/starlette-0.49.3-py3-none-any.whl", hash = "sha256:b579b99715fdc2980cf88c8ec96d3bf1ce16f5a8051a7c2b84ef9b1cdecaea2f", size = 74340, upload-time = "2025-11-01T15:12:24.387Z" }, +] + +[[package]] +name = "starlette" +version = "0.52.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.10'", +] +dependencies = [ + { name = "anyio", marker = "python_full_version >= '3.10'" }, + { name = "typing-extensions", marker = "python_full_version >= '3.10' and python_full_version < '3.13'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0a/69/662169fdb92fb96ec3eaee218cf540a629d629c86d7993d9651226a6789b/starlette-0.47.1.tar.gz", hash = "sha256:aef012dd2b6be325ffa16698f9dc533614fb1cebd593a906b90dc1025529a79b", size = 2583072, upload-time = "2025-06-21T04:03:17.337Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c4/68/79977123bb7be889ad680d79a40f339082c1978b5cfcf62c2d8d196873ac/starlette-0.52.1.tar.gz", hash = "sha256:834edd1b0a23167694292e94f597773bc3f89f362be6effee198165a35d62933", size = 2653702, upload-time = "2026-01-18T13:34:11.062Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/82/95/38ef0cd7fa11eaba6a99b3c4f5ac948d8bc6ff199aabd327a29cc000840c/starlette-0.47.1-py3-none-any.whl", hash = "sha256:5e11c9f5c7c3f24959edbf2dffdc01bba860228acf657129467d8a7468591527", size = 72747, upload-time = "2025-06-21T04:03:15.705Z" }, + { url = "https://files.pythonhosted.org/packages/81/0d/13d1d239a25cbfb19e740db83143e95c772a1fe10202dda4b76792b114dd/starlette-0.52.1-py3-none-any.whl", hash = "sha256:0029d43eb3d273bc4f83a08720b4912ea4b071087a3b48db01b7c839f7954d74", size = 74272, upload-time = "2026-01-18T13:34:09.188Z" }, ] [[package]]