From 3e60fa16b628ac72341015c3f1c31ff56cb5a2cd Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Wed, 7 Jan 2026 09:29:31 +0100 Subject: [PATCH 01/12] fix: streaming --- fasta2a/__init__.py | 6 +-- fasta2a/applications.py | 27 +++++++++++- fasta2a/broker.py | 95 +++++++++++++++++++++++++++++++++++++++-- fasta2a/schema.py | 6 +++ fasta2a/storage.py | 94 +++++++++++++++++++++++++++++++++++++++- fasta2a/task_manager.py | 45 +++++++++++++++++-- pyproject.toml | 4 ++ 7 files changed, 264 insertions(+), 13 deletions(-) diff --git a/fasta2a/__init__.py b/fasta2a/__init__.py index 4a8b106..57a02f5 100644 --- a/fasta2a/__init__.py +++ b/fasta2a/__init__.py @@ -1,7 +1,7 @@ from .applications import FastA2A from .broker import Broker -from .schema import Skill -from .storage import Storage +from .schema import Skill, StreamEvent +from .storage import Storage, StreamingStorageWrapper from .worker import Worker -__all__ = ['FastA2A', 'Skill', 'Storage', 'Broker', 'Worker'] +__all__ = ['FastA2A', 'Skill', 'Storage', 'StreamingStorageWrapper', 'Broker', 'Worker', 'StreamEvent'] diff --git a/fasta2a/applications.py b/fasta2a/applications.py index 958530d..896384c 100644 --- a/fasta2a/applications.py +++ b/fasta2a/applications.py @@ -1,10 +1,12 @@ from __future__ import annotations as _annotations +import json from collections.abc import AsyncIterator, Sequence from contextlib import asynccontextmanager from pathlib import Path from typing import Any +from sse_starlette import EventSourceResponse from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.requests import Request @@ -21,6 +23,8 @@ a2a_request_ta, a2a_response_ta, agent_card_ta, + stream_event_ta, + stream_message_request_ta, ) from .storage import Storage from .task_manager import TaskManager @@ -42,6 +46,7 @@ def __init__( provider: AgentProvider | None = None, skills: list[Skill] | None = None, docs_url: str | None = '/docs', + streaming: bool = False, # Starlette debug: bool = False, routes: Sequence[Route] | None = None, @@ -67,6 +72,7 @@ def __init__( self.provider = provider self.skills = skills or [] self.docs_url = docs_url + self.streaming = streaming # NOTE: For now, I don't think there's any reason to support any other input/output modes. self.default_input_modes = ['application/json'] self.default_output_modes = ['application/json'] @@ -100,7 +106,7 @@ async def _agent_card_endpoint(self, request: Request) -> Response: default_input_modes=self.default_input_modes, default_output_modes=self.default_output_modes, capabilities=AgentCapabilities( - streaming=False, push_notifications=False, state_transition_history=False + streaming=self.streaming, push_notifications=False, state_transition_history=False ), ) if self.provider is not None: @@ -131,6 +137,25 @@ async def _agent_run_endpoint(self, request: Request) -> Response: if a2a_request['method'] == 'message/send': jsonrpc_response = await self.task_manager.send_message(a2a_request) + elif a2a_request['method'] == 'message/stream': + # Parse the streaming request + stream_request = stream_message_request_ta.validate_json(data) + + # Create an async generator wrapper that formats events as JSON-RPC responses + async def sse_generator(): + request_id = stream_request.get('id') + async for event in self.task_manager.stream_message(stream_request): + # Serialize event to ensure proper camelCase conversion + event_dict = stream_event_ta.dump_python(event, mode='json', by_alias=True) + + # Wrap in JSON-RPC response + jsonrpc_response = {'jsonrpc': '2.0', 'id': request_id, 'result': event_dict} + + # Convert to JSON string + yield json.dumps(jsonrpc_response) + + # Return SSE response + return EventSourceResponse(sse_generator()) elif a2a_request['method'] == 'tasks/get': jsonrpc_response = await self.task_manager.get_task(a2a_request) elif a2a_request['method'] == 'tasks/cancel': diff --git a/fasta2a/broker.py b/fasta2a/broker.py index c84b738..7ca631f 100644 --- a/fasta2a/broker.py +++ b/fasta2a/broker.py @@ -7,11 +7,12 @@ from typing import Annotated, Any, Generic, Literal, TypeVar import anyio +from anyio.streams.memory import MemoryObjectSendStream from opentelemetry.trace import Span, get_current_span, get_tracer from pydantic import Discriminator from typing_extensions import Self, TypedDict -from .schema import TaskIdParams, TaskSendParams +from .schema import StreamEvent, TaskIdParams, TaskSendParams tracer = get_tracer(__name__) @@ -30,12 +31,32 @@ class Broker(ABC): @abstractmethod async def run_task(self, params: TaskSendParams) -> None: """Send a task to be executed by the worker.""" - raise NotImplementedError('send_run_task is not implemented yet.') + ... @abstractmethod async def cancel_task(self, params: TaskIdParams) -> None: """Cancel a task.""" - raise NotImplementedError('send_cancel_task is not implemented yet.') + ... + + @abstractmethod + async def send_stream_event(self, task_id: str, event: StreamEvent) -> None: + """Send a streaming event from worker to subscribers. + + This is used by workers to publish status updates, messages, and artifacts + during task execution. Events are forwarded to all active subscribers of + the given task_id. + """ + ... + + @abstractmethod + def subscribe_to_stream(self, task_id: str) -> AsyncIterator[StreamEvent]: + """Subscribe to streaming events for a specific task. + + Returns an async iterator that yields events published by workers for the + given task_id. The iterator completes when a TaskStatusUpdateEvent with + final=True is received or the subscription is cancelled. + """ + ... @abstractmethod async def __aenter__(self) -> Self: ... @@ -73,6 +94,10 @@ class _TaskOperation(TypedDict, Generic[OperationT, ParamsT]): class InMemoryBroker(Broker): """A broker that schedules tasks in memory.""" + def __init__(self) -> None: + self._event_subscribers: dict[str, list[MemoryObjectSendStream[StreamEvent]]] = {} + self._subscriber_lock: anyio.Lock | None = None + async def __aenter__(self): self.aexit_stack = AsyncExitStack() await self.aexit_stack.__aenter__() @@ -81,6 +106,8 @@ async def __aenter__(self): await self.aexit_stack.enter_async_context(self._read_stream) await self.aexit_stack.enter_async_context(self._write_stream) + self._subscriber_lock = anyio.Lock() + return self async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): @@ -96,3 +123,65 @@ async def receive_task_operations(self) -> AsyncIterator[TaskOperation]: """Receive task operations from the broker.""" async for task_operation in self._read_stream: yield task_operation + + async def send_stream_event(self, task_id: str, event: StreamEvent) -> None: + """Send a streaming event from worker to subscribers.""" + assert self._subscriber_lock is not None, 'Broker not initialized' + + async with self._subscriber_lock: + subscribers = self._event_subscribers.get(task_id, []) + if not subscribers: + return + + # Send event to all subscribers, removing closed streams + active_subscribers: list[MemoryObjectSendStream[StreamEvent]] = [] + for stream in subscribers: + try: + await stream.send(event) + active_subscribers.append(stream) + except (anyio.ClosedResourceError, anyio.BrokenResourceError): + # Subscriber disconnected, remove from list + pass + + # Update subscriber list with only active ones + if active_subscribers: + self._event_subscribers[task_id] = active_subscribers + elif task_id in self._event_subscribers: + # No active subscribers left, clean up + del self._event_subscribers[task_id] + + async def subscribe_to_stream(self, task_id: str) -> AsyncIterator[StreamEvent]: + """Subscribe to streaming events for a specific task.""" + assert self._subscriber_lock is not None, 'Broker not initialized' + + # Create a new stream for this subscriber + send_stream, receive_stream = anyio.create_memory_object_stream[StreamEvent](max_buffer_size=100) + + # Register the subscriber + async with self._subscriber_lock: + if task_id not in self._event_subscribers: + self._event_subscribers[task_id] = [] + self._event_subscribers[task_id].append(send_stream) + + try: + async with receive_stream: + async for event in receive_stream: + yield event + + # Check if this is a final status update + if isinstance(event, dict) and event.get('kind') == 'status-update' and event.get('final', False): + break + finally: + # Clean up subscription on exit + async with self._subscriber_lock: + if task_id in self._event_subscribers: + try: + self._event_subscribers[task_id].remove(send_stream) + if not self._event_subscribers[task_id]: + del self._event_subscribers[task_id] + except ValueError: + # Already removed + pass + + # Close the send stream + await send_stream.aclose() diff --git a/fasta2a/schema.py b/fasta2a/schema.py index 37ffb86..6e53701 100644 --- a/fasta2a/schema.py +++ b/fasta2a/schema.py @@ -808,3 +808,9 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): send_message_response_ta: TypeAdapter[SendMessageResponse] = TypeAdapter(SendMessageResponse) stream_message_request_ta: TypeAdapter[StreamMessageRequest] = TypeAdapter(StreamMessageRequest) stream_message_response_ta: TypeAdapter[StreamMessageResponse] = TypeAdapter(StreamMessageResponse) + +# Type for streaming events (used by broker and task manager) +StreamEvent = Union[Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent] +"""A streaming event that can be sent during message/stream requests.""" + +stream_event_ta: TypeAdapter[StreamEvent] = TypeAdapter(StreamEvent) diff --git a/fasta2a/storage.py b/fasta2a/storage.py index 58934c1..e6ccb15 100644 --- a/fasta2a/storage.py +++ b/fasta2a/storage.py @@ -5,11 +5,22 @@ import uuid from abc import ABC, abstractmethod from datetime import datetime -from typing import Any, Generic +from typing import TYPE_CHECKING, Any, Generic from typing_extensions import TypeVar -from .schema import Artifact, Message, Task, TaskState, TaskStatus +from .schema import ( + Artifact, + Message, + Task, + TaskArtifactUpdateEvent, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, +) + +if TYPE_CHECKING: + from .broker import Broker ContextT = TypeVar('ContextT', default=Any) @@ -129,3 +140,82 @@ async def update_context(self, context_id: str, context: ContextT) -> None: async def load_context(self, context_id: str) -> ContextT | None: """Retrieve the stored context given the `context_id`.""" return self.contexts.get(context_id) + + +class StreamingStorageWrapper(Storage[ContextT]): + """A storage wrapper that publishes streaming events when tasks are updated. + + This wrapper intercepts update_task calls and publishes TaskStatusUpdateEvent + and TaskArtifactUpdateEvent to the broker, enabling SSE streaming without + modifying the underlying storage or worker implementations. + """ + + def __init__(self, storage: Storage[ContextT], broker: Broker): + self._storage = storage + self._broker = broker + + async def load_task(self, task_id: str, history_length: int | None = None) -> Task | None: + return await self._storage.load_task(task_id, history_length) + + async def submit_task(self, context_id: str, message: Message) -> Task: + return await self._storage.submit_task(context_id, message) + + async def update_task( + self, + task_id: str, + state: TaskState, + new_artifacts: list[Artifact] | None = None, + new_messages: list[Message] | None = None, + ) -> Task: + """Update task and publish streaming events.""" + # Update the underlying storage first + task = await self._storage.update_task(task_id, state, new_artifacts, new_messages) + + # Determine if this is a final state + final = state in ('completed', 'failed', 'canceled') + + # For non-final updates, publish status first + if not final: + status_event = TaskStatusUpdateEvent( + kind='status-update', + task_id=task_id, + context_id=task['context_id'], + status=task['status'], + final=False, + ) + await self._broker.send_stream_event(task_id, status_event) + + # Publish message events BEFORE final status (so subscriber receives them) + if new_messages: + for message in new_messages: + await self._broker.send_stream_event(task_id, message) + + # Publish artifact events + if new_artifacts: + for artifact in new_artifacts: + artifact_event = TaskArtifactUpdateEvent( + kind='artifact-update', + task_id=task_id, + context_id=task['context_id'], + artifact=artifact, + ) + await self._broker.send_stream_event(task_id, artifact_event) + + # For final updates, publish status LAST (after messages and artifacts) + if final: + status_event = TaskStatusUpdateEvent( + kind='status-update', + task_id=task_id, + context_id=task['context_id'], + status=task['status'], + final=True, + ) + await self._broker.send_stream_event(task_id, status_event) + + return task + + async def load_context(self, context_id: str) -> ContextT | None: + return await self._storage.load_context(context_id) + + async def update_context(self, context_id: str, context: ContextT) -> None: + await self._storage.update_context(context_id, context) diff --git a/fasta2a/task_manager.py b/fasta2a/task_manager.py index 54ab709..9457ce8 100644 --- a/fasta2a/task_manager.py +++ b/fasta2a/task_manager.py @@ -60,7 +60,9 @@ from __future__ import annotations as _annotations +import asyncio import uuid +from collections.abc import AsyncGenerator from contextlib import AsyncExitStack from dataclasses import dataclass, field from typing import Any @@ -78,8 +80,8 @@ SendMessageResponse, SetTaskPushNotificationRequest, SetTaskPushNotificationResponse, + StreamEvent, StreamMessageRequest, - StreamMessageResponse, TaskNotFoundError, TaskSendParams, ) @@ -156,9 +158,44 @@ async def cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: ) return CancelTaskResponse(jsonrpc='2.0', id=request['id'], result=task) - async def stream_message(self, request: StreamMessageRequest) -> StreamMessageResponse: - """Stream messages using Server-Sent Events.""" - raise NotImplementedError('message/stream method is not implemented yet.') + async def stream_message(self, request: StreamMessageRequest) -> AsyncGenerator[StreamEvent, None]: + """Handle a streaming message request. + + This method: + 1. Creates and submits a new task + 2. Yields the initial task object + 3. Subscribes to the broker's event stream + 4. Starts task execution asynchronously + 5. Streams all events until completion + """ + # Extract parameters + params = request['params'] + message = params['message'] + context_id = message.get('context_id', str(uuid.uuid4())) + + # Create and submit the task + task = await self.storage.submit_task(context_id, message) + + # Yield the initial task + yield task + + # Prepare broker params + broker_params: TaskSendParams = {'id': task['id'], 'context_id': context_id, 'message': message} + config = params.get('configuration', {}) + history_length = config.get('history_length') + if history_length is not None: + broker_params['history_length'] = history_length + + metadata = params.get('metadata') + if metadata is not None: + broker_params['metadata'] = metadata + + # Start task execution in background + asyncio.create_task(self.broker.run_task(broker_params)) + + # Stream events from broker + async for event in self.broker.subscribe_to_stream(task['id']): + yield event async def set_task_push_notification( self, request: SetTaskPushNotificationRequest diff --git a/pyproject.toml b/pyproject.toml index 2af3b7d..ee2a70b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ "pydantic>=2.10", "opentelemetry-api>=1.28.0", "eval_type_backport>=0.2.2; python_version <= '3.9'", + "sse-starlette>=2.0.0", ] [project.optional-dependencies] @@ -58,8 +59,11 @@ dev = [ "asgi-lifespan", "coverage", "httpx", + "httpx-sse", "inline-snapshot", "pytest", + "pytest-asyncio", + "pytest-mock", "ruff", "pyright", ] From 8a5dbfed30adda56968d57add522d9624a891320 Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Wed, 7 Jan 2026 09:40:46 +0100 Subject: [PATCH 02/12] test: streaming --- tests/test_streaming.py | 521 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 521 insertions(+) create mode 100644 tests/test_streaming.py diff --git a/tests/test_streaming.py b/tests/test_streaming.py new file mode 100644 index 0000000..5dfd356 --- /dev/null +++ b/tests/test_streaming.py @@ -0,0 +1,521 @@ +"""Tests for the SSE streaming feature in fasta2a. + +This module tests: +- StreamingStorageWrapper event publishing +- InMemoryBroker pub/sub for streaming +- TaskManager stream_message method +- FastA2A message/stream endpoint +""" + +from __future__ import annotations as _annotations + +import asyncio +import json +import uuid +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import Any + +import httpx +import pytest +from asgi_lifespan import LifespanManager + +from fasta2a import FastA2A, Worker +from fasta2a.broker import InMemoryBroker +from fasta2a.schema import ( + Artifact, + Message, + TaskIdParams, + TaskSendParams, + TextPart, +) +from fasta2a.storage import InMemoryStorage, StreamingStorageWrapper + +pytestmark = pytest.mark.anyio + + +# Test fixtures and helpers + +@asynccontextmanager +async def create_test_client(app: FastA2A): + """Create a test client for the FastA2A app.""" + async with LifespanManager(app=app) as manager: + transport = httpx.ASGITransport(app=manager.app) + async with httpx.AsyncClient(transport=transport, base_url='http://testclient') as client: + yield client + + +Context = list[Message] + + +@dataclass +class EchoWorker(Worker[Context]): + """A simple worker for testing that echoes messages.""" + + response_text: str = "Hello from test worker!" + delay: float = 0.1 + + async def run_task(self, params: TaskSendParams) -> None: + task = await self.storage.load_task(params['id']) + assert task is not None + + # Update to working state + await self.storage.update_task(task['id'], state='working') + + # Simulate some work + await asyncio.sleep(self.delay) + + # Create response message + message = Message( + role='agent', + parts=[TextPart(text=self.response_text, kind='text')], + kind='message', + message_id=str(uuid.uuid4()), + ) + + # Create artifact + artifact = Artifact( + artifact_id=str(uuid.uuid4()), + name='result', + parts=[TextPart(text=self.response_text, kind='text')], + ) + + # Complete the task with message and artifact + await self.storage.update_task( + task['id'], + state='completed', + new_messages=[message], + new_artifacts=[artifact], + ) + + async def cancel_task(self, params: TaskIdParams) -> None: + await self.storage.update_task(params['id'], state='canceled') + + def build_message_history(self, history: list[Message]) -> list[Any]: + return history + + def build_artifacts(self, result: Any) -> list[Artifact]: + return [] + + +@asynccontextmanager +async def create_streaming_app(response_text: str = "Hello!"): + """Create a FastA2A app with streaming enabled.""" + broker = InMemoryBroker() + base_storage = InMemoryStorage() + streaming_storage = StreamingStorageWrapper(base_storage, broker) + + worker = EchoWorker( + broker=broker, + storage=streaming_storage, + response_text=response_text, + ) + + @asynccontextmanager + async def lifespan(app: FastA2A): + async with app.task_manager: + async with worker.run(): + yield + + app = FastA2A( + storage=streaming_storage, + broker=broker, + streaming=True, + lifespan=lifespan, + ) + + yield app + + +# Tests for InMemoryBroker streaming + +class TestInMemoryBrokerStreaming: + """Tests for InMemoryBroker pub/sub streaming.""" + + async def test_subscribe_and_receive_events(self): + """Test that subscribers receive events sent by send_stream_event.""" + broker = InMemoryBroker() + + async with broker: + task_id = "test-task-123" + received_events = [] + + # Start subscriber in background + async def subscriber(): + async for event in broker.subscribe_to_stream(task_id): + received_events.append(event) + if event.get('final', False): + break + + subscriber_task = asyncio.create_task(subscriber()) + + # Give subscriber time to register + await asyncio.sleep(0.01) + + # Send events + await broker.send_stream_event(task_id, {'kind': 'status-update', 'final': False}) + await broker.send_stream_event(task_id, {'kind': 'message', 'text': 'Hello'}) + await broker.send_stream_event(task_id, {'kind': 'status-update', 'final': True}) + + # Wait for subscriber to finish + await asyncio.wait_for(subscriber_task, timeout=1.0) + + assert len(received_events) == 3 + assert received_events[0]['kind'] == 'status-update' + assert received_events[1]['kind'] == 'message' + assert received_events[2]['final'] is True + + async def test_multiple_subscribers(self): + """Test that multiple subscribers receive the same events.""" + broker = InMemoryBroker() + + async with broker: + task_id = "test-task-456" + received_1 = [] + received_2 = [] + + async def subscriber1(): + async for event in broker.subscribe_to_stream(task_id): + received_1.append(event) + if event.get('final', False): + break + + async def subscriber2(): + async for event in broker.subscribe_to_stream(task_id): + received_2.append(event) + if event.get('final', False): + break + + task1 = asyncio.create_task(subscriber1()) + task2 = asyncio.create_task(subscriber2()) + + await asyncio.sleep(0.01) + + await broker.send_stream_event(task_id, {'kind': 'test', 'final': False}) + await broker.send_stream_event(task_id, {'kind': 'done', 'final': True}) + + await asyncio.wait_for(asyncio.gather(task1, task2), timeout=1.0) + + assert len(received_1) == 2 + assert len(received_2) == 2 + + async def test_no_subscribers_doesnt_error(self): + """Test that sending events with no subscribers doesn't raise errors.""" + broker = InMemoryBroker() + + async with broker: + # Should not raise + await broker.send_stream_event("nonexistent-task", {'kind': 'test'}) + + +# Tests for StreamingStorageWrapper + +class TestStreamingStorageWrapper: + """Tests for StreamingStorageWrapper event publishing.""" + + async def test_publishes_status_update_on_working(self): + """Test that updating to 'working' publishes a status-update event.""" + broker = InMemoryBroker() + base_storage = InMemoryStorage() + storage = StreamingStorageWrapper(base_storage, broker) + + async with broker: + # Create a task + message = Message( + role='user', + parts=[TextPart(text='Hello', kind='text')], + kind='message', + message_id='msg-1', + ) + task = await storage.submit_task('ctx-1', message) + task_id = task['id'] + + received_events = [] + + async def subscriber(): + async for event in broker.subscribe_to_stream(task_id): + received_events.append(event) + # Stop after first event for this test + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.01) + + # Update to working + await storage.update_task(task_id, state='working') + + await asyncio.wait_for(sub_task, timeout=1.0) + + assert len(received_events) == 1 + assert received_events[0]['kind'] == 'status-update' + assert received_events[0]['status']['state'] == 'working' + assert received_events[0]['final'] is False + + async def test_publishes_message_before_final_status(self): + """Test that messages are published before the final status update.""" + broker = InMemoryBroker() + base_storage = InMemoryStorage() + storage = StreamingStorageWrapper(base_storage, broker) + + async with broker: + message = Message( + role='user', + parts=[TextPart(text='Hello', kind='text')], + kind='message', + message_id='msg-1', + ) + task = await storage.submit_task('ctx-1', message) + task_id = task['id'] + + received_events = [] + + async def subscriber(): + async for event in broker.subscribe_to_stream(task_id): + received_events.append(event) + if event.get('final', False): + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.01) + + # Complete with a new message + agent_message = Message( + role='agent', + parts=[TextPart(text='Hi there!', kind='text')], + kind='message', + message_id='msg-2', + ) + await storage.update_task(task_id, state='completed', new_messages=[agent_message]) + + await asyncio.wait_for(sub_task, timeout=1.0) + + # Should have: message, then final status + assert len(received_events) == 2 + assert received_events[0]['kind'] == 'message' + assert received_events[0]['role'] == 'agent' + assert received_events[1]['kind'] == 'status-update' + assert received_events[1]['final'] is True + + async def test_publishes_artifacts(self): + """Test that artifacts are published.""" + broker = InMemoryBroker() + base_storage = InMemoryStorage() + storage = StreamingStorageWrapper(base_storage, broker) + + async with broker: + message = Message( + role='user', + parts=[TextPart(text='Hello', kind='text')], + kind='message', + message_id='msg-1', + ) + task = await storage.submit_task('ctx-1', message) + task_id = task['id'] + + received_events = [] + + async def subscriber(): + async for event in broker.subscribe_to_stream(task_id): + received_events.append(event) + if event.get('final', False): + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.01) + + artifact = Artifact( + artifact_id='art-1', + name='result', + parts=[TextPart(text='Result data', kind='text')], + ) + await storage.update_task(task_id, state='completed', new_artifacts=[artifact]) + + await asyncio.wait_for(sub_task, timeout=1.0) + + # Should have: artifact, then final status + assert len(received_events) == 2 + assert received_events[0]['kind'] == 'artifact-update' + assert received_events[0]['artifact']['name'] == 'result' + assert received_events[1]['kind'] == 'status-update' + assert received_events[1]['final'] is True + + +# Tests for FastA2A streaming endpoint + +class TestFastA2AStreaming: + """Tests for the FastA2A message/stream SSE endpoint.""" + + async def test_agent_card_shows_streaming_enabled(self): + """Test that agent card reflects streaming capability.""" + async with create_streaming_app() as app: + async with create_test_client(app) as client: + response = await client.get('/.well-known/agent-card.json') + assert response.status_code == 200 + data = response.json() + assert data['capabilities']['streaming'] is True + + async def test_message_stream_returns_sse(self): + """Test that message/stream returns SSE response.""" + async with create_streaming_app(response_text="Test response") as app: + async with create_test_client(app) as client: + payload = { + 'jsonrpc': '2.0', + 'id': 'test-request-1', + 'method': 'message/stream', + 'params': { + 'message': { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello'}], + 'kind': 'message', + 'messageId': 'user-msg-1', + } + } + } + + # Use streaming request + async with client.stream('POST', '/', json=payload) as response: + assert response.status_code == 200 + assert 'text/event-stream' in response.headers.get('content-type', '') + + events = [] + async for line in response.aiter_lines(): + if line.startswith('data: '): + data = json.loads(line[6:]) + events.append(data) + + # Should have multiple events + assert len(events) >= 3 # task, working status, completed status + + # First event should be the task + assert events[0]['result']['kind'] == 'task' + assert events[0]['result']['status']['state'] == 'submitted' + + # Should have a working status + working_events = [e for e in events if e['result'].get('kind') == 'status-update' and e['result'].get('status', {}).get('state') == 'working'] + assert len(working_events) >= 1 + + # Should have a message with the response + message_events = [e for e in events if e['result'].get('kind') == 'message' and e['result'].get('role') == 'agent'] + assert len(message_events) >= 1 + assert 'Test response' in message_events[0]['result']['parts'][0]['text'] + + # Last event should be final status + final_events = [e for e in events if e['result'].get('final') is True] + assert len(final_events) >= 1 + + async def test_message_stream_includes_context_id(self): + """Test that streamed events include context_id.""" + async with create_streaming_app() as app: + async with create_test_client(app) as client: + context_id = str(uuid.uuid4()) + payload = { + 'jsonrpc': '2.0', + 'id': 'test-2', + 'method': 'message/stream', + 'params': { + 'message': { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hi'}], + 'kind': 'message', + 'messageId': 'msg-1', + 'contextId': context_id, + } + } + } + + async with client.stream('POST', '/', json=payload) as response: + events = [] + async for line in response.aiter_lines(): + if line.startswith('data: '): + events.append(json.loads(line[6:])) + + # All events should have the same context_id + for event in events: + result = event['result'] + event_context = result.get('contextId') or result.get('context_id') + assert event_context == context_id + + async def test_message_send_still_works(self): + """Test that non-streaming message/send still works.""" + async with create_streaming_app() as app: + async with create_test_client(app) as client: + payload = { + 'jsonrpc': '2.0', + 'id': 'test-3', + 'method': 'message/send', + 'params': { + 'message': { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello'}], + 'kind': 'message', + 'messageId': 'msg-1', + } + } + } + + response = await client.post('/', json=payload) + assert response.status_code == 200 + data = response.json() + assert data['result']['kind'] == 'task' + assert data['result']['status']['state'] == 'submitted' + + +# Tests for TaskManager stream_message + +class TestTaskManagerStreamMessage: + """Tests for TaskManager.stream_message method.""" + + async def test_stream_message_yields_events_in_order(self): + """Test that stream_message yields events: task, status updates, messages, final status.""" + broker = InMemoryBroker() + base_storage = InMemoryStorage() + streaming_storage = StreamingStorageWrapper(base_storage, broker) + + # Use a longer delay to ensure we capture the initial task before it completes + worker = EchoWorker(broker=broker, storage=streaming_storage, delay=0.2) + + async with broker: + async with worker.run(): + from fasta2a.task_manager import TaskManager + + task_manager = TaskManager(broker=broker, storage=streaming_storage) + async with task_manager: + request = { + 'jsonrpc': '2.0', + 'id': 'req-1', + 'method': 'message/stream', + 'params': { + 'message': { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Test'}], + 'kind': 'message', + 'messageId': 'msg-1', + } + } + } + + events = [] + async for event in task_manager.stream_message(request): + events.append(event) + # Stop when we get the final status + if event.get('kind') == 'status-update' and event.get('final'): + break + + # Should have at least: task, working status, message, final status + assert len(events) >= 4 + + # First event should be the task (submitted state) + assert events[0]['kind'] == 'task' + + # Should have a working status update + working_events = [e for e in events if e.get('kind') == 'status-update' and e.get('status', {}).get('state') == 'working'] + assert len(working_events) >= 1 + + # Should have an agent message + message_events = [e for e in events if e.get('kind') == 'message' and e.get('role') == 'agent'] + assert len(message_events) >= 1 + + # Last event should be final status + assert events[-1]['kind'] == 'status-update' + assert events[-1]['final'] is True + assert events[-1]['status']['state'] == 'completed' From 5903cff8cf4fcefdc7f6ba50b89c6559948fbc53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=A9onore=20Charles?= Date: Fri, 9 Jan 2026 16:19:45 +0100 Subject: [PATCH 03/12] fix CI --- fasta2a/storage.py | 2 +- tests/test_streaming.py | 185 +++++++++++++++++++++++++++------------- 2 files changed, 129 insertions(+), 58 deletions(-) diff --git a/fasta2a/storage.py b/fasta2a/storage.py index e6ccb15..26ee2f5 100644 --- a/fasta2a/storage.py +++ b/fasta2a/storage.py @@ -144,7 +144,7 @@ async def load_context(self, context_id: str) -> ContextT | None: class StreamingStorageWrapper(Storage[ContextT]): """A storage wrapper that publishes streaming events when tasks are updated. - + This wrapper intercepts update_task calls and publishes TaskStatusUpdateEvent and TaskArtifactUpdateEvent to the broker, enabling SSE streaming without modifying the underlying storage or worker implementations. diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 5dfd356..d34a95f 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -25,8 +25,13 @@ from fasta2a.schema import ( Artifact, Message, + MessageSendParams, + StreamEvent, + StreamMessageRequest, TaskIdParams, TaskSendParams, + TaskStatus, + TaskStatusUpdateEvent, TextPart, ) from fasta2a.storage import InMemoryStorage, StreamingStorageWrapper @@ -36,6 +41,7 @@ # Test fixtures and helpers + @asynccontextmanager async def create_test_client(app: FastA2A): """Create a test client for the FastA2A app.""" @@ -52,7 +58,7 @@ async def create_test_client(app: FastA2A): class EchoWorker(Worker[Context]): """A simple worker for testing that echoes messages.""" - response_text: str = "Hello from test worker!" + response_text: str = 'Hello from test worker!' delay: float = 0.1 async def run_task(self, params: TaskSendParams) -> None: @@ -99,12 +105,12 @@ def build_artifacts(self, result: Any) -> list[Artifact]: @asynccontextmanager -async def create_streaming_app(response_text: str = "Hello!"): +async def create_streaming_app(response_text: str = 'Hello!'): """Create a FastA2A app with streaming enabled.""" broker = InMemoryBroker() base_storage = InMemoryStorage() streaming_storage = StreamingStorageWrapper(base_storage, broker) - + worker = EchoWorker( broker=broker, storage=streaming_storage, @@ -129,22 +135,24 @@ async def lifespan(app: FastA2A): # Tests for InMemoryBroker streaming + class TestInMemoryBrokerStreaming: """Tests for InMemoryBroker pub/sub streaming.""" async def test_subscribe_and_receive_events(self): """Test that subscribers receive events sent by send_stream_event.""" broker = InMemoryBroker() - + async with broker: - task_id = "test-task-123" - received_events = [] + task_id = 'test-task-123' + received_events: list[StreamEvent] = [] # Start subscriber in background async def subscriber(): async for event in broker.subscribe_to_stream(task_id): received_events.append(event) - if event.get('final', False): + # Check if event has a 'final' field (for status events) + if isinstance(event, dict) and event.get('final', False): break subscriber_task = asyncio.create_task(subscriber()) @@ -153,37 +161,58 @@ async def subscriber(): await asyncio.sleep(0.01) # Send events - await broker.send_stream_event(task_id, {'kind': 'status-update', 'final': False}) - await broker.send_stream_event(task_id, {'kind': 'message', 'text': 'Hello'}) - await broker.send_stream_event(task_id, {'kind': 'status-update', 'final': True}) + status_event_1 = TaskStatusUpdateEvent( + task_id=task_id, + context_id='test-context', + kind='status-update', + status=TaskStatus(state='working'), + final=False, + ) + message_event = Message( + message_id='test-msg-1', + role='agent', + parts=[TextPart(kind='text', text='Hello')], + kind='message', + ) + status_event_2 = TaskStatusUpdateEvent( + task_id=task_id, + context_id='test-context', + kind='status-update', + status=TaskStatus(state='completed'), + final=True, + ) + await broker.send_stream_event(task_id, status_event_1) + await broker.send_stream_event(task_id, message_event) + await broker.send_stream_event(task_id, status_event_2) # Wait for subscriber to finish await asyncio.wait_for(subscriber_task, timeout=1.0) assert len(received_events) == 3 assert received_events[0]['kind'] == 'status-update' - assert received_events[1]['kind'] == 'message' - assert received_events[2]['final'] is True + assert received_events[1]['kind'] == 'message' # Message has both kind and role + # Only check 'final' on status update events + assert received_events[2]['kind'] == 'status-update' and received_events[2]['final'] is True async def test_multiple_subscribers(self): """Test that multiple subscribers receive the same events.""" broker = InMemoryBroker() - + async with broker: - task_id = "test-task-456" - received_1 = [] - received_2 = [] + task_id = 'test-task-456' + received_1: list[StreamEvent] = [] + received_2: list[StreamEvent] = [] async def subscriber1(): async for event in broker.subscribe_to_stream(task_id): received_1.append(event) - if event.get('final', False): + if isinstance(event, dict) and event.get('final', False): break async def subscriber2(): async for event in broker.subscribe_to_stream(task_id): received_2.append(event) - if event.get('final', False): + if isinstance(event, dict) and event.get('final', False): break task1 = asyncio.create_task(subscriber1()) @@ -191,8 +220,26 @@ async def subscriber2(): await asyncio.sleep(0.01) - await broker.send_stream_event(task_id, {'kind': 'test', 'final': False}) - await broker.send_stream_event(task_id, {'kind': 'done', 'final': True}) + await broker.send_stream_event( + task_id, + TaskStatusUpdateEvent( + task_id=task_id, + context_id='test-context', + kind='status-update', + status=TaskStatus(state='working'), + final=False, + ), + ) + await broker.send_stream_event( + task_id, + TaskStatusUpdateEvent( + task_id=task_id, + context_id='test-context', + kind='status-update', + status=TaskStatus(state='completed'), + final=True, + ), + ) await asyncio.wait_for(asyncio.gather(task1, task2), timeout=1.0) @@ -202,14 +249,24 @@ async def subscriber2(): async def test_no_subscribers_doesnt_error(self): """Test that sending events with no subscribers doesn't raise errors.""" broker = InMemoryBroker() - + async with broker: # Should not raise - await broker.send_stream_event("nonexistent-task", {'kind': 'test'}) + await broker.send_stream_event( + 'nonexistent-task', + TaskStatusUpdateEvent( + task_id='nonexistent-task', + context_id='test-context', + kind='status-update', + status=TaskStatus(state='working'), + final=False, + ), + ) # Tests for StreamingStorageWrapper + class TestStreamingStorageWrapper: """Tests for StreamingStorageWrapper event publishing.""" @@ -218,7 +275,7 @@ async def test_publishes_status_update_on_working(self): broker = InMemoryBroker() base_storage = InMemoryStorage() storage = StreamingStorageWrapper(base_storage, broker) - + async with broker: # Create a task message = Message( @@ -230,7 +287,7 @@ async def test_publishes_status_update_on_working(self): task = await storage.submit_task('ctx-1', message) task_id = task['id'] - received_events = [] + received_events: list[StreamEvent] = [] async def subscriber(): async for event in broker.subscribe_to_stream(task_id): @@ -256,7 +313,7 @@ async def test_publishes_message_before_final_status(self): broker = InMemoryBroker() base_storage = InMemoryStorage() storage = StreamingStorageWrapper(base_storage, broker) - + async with broker: message = Message( role='user', @@ -267,12 +324,12 @@ async def test_publishes_message_before_final_status(self): task = await storage.submit_task('ctx-1', message) task_id = task['id'] - received_events = [] + received_events: list[StreamEvent] = [] async def subscriber(): async for event in broker.subscribe_to_stream(task_id): received_events.append(event) - if event.get('final', False): + if isinstance(event, dict) and event.get('final', False): break sub_task = asyncio.create_task(subscriber()) @@ -301,7 +358,7 @@ async def test_publishes_artifacts(self): broker = InMemoryBroker() base_storage = InMemoryStorage() storage = StreamingStorageWrapper(base_storage, broker) - + async with broker: message = Message( role='user', @@ -312,12 +369,12 @@ async def test_publishes_artifacts(self): task = await storage.submit_task('ctx-1', message) task_id = task['id'] - received_events = [] + received_events: list[StreamEvent] = [] async def subscriber(): async for event in broker.subscribe_to_stream(task_id): received_events.append(event) - if event.get('final', False): + if isinstance(event, dict) and event.get('final', False): break sub_task = asyncio.create_task(subscriber()) @@ -335,13 +392,15 @@ async def subscriber(): # Should have: artifact, then final status assert len(received_events) == 2 assert received_events[0]['kind'] == 'artifact-update' - assert received_events[0]['artifact']['name'] == 'result' + artifact = received_events[0]['artifact'] + assert artifact.get('name') == 'result' # Use .get() for NotRequired fields assert received_events[1]['kind'] == 'status-update' assert received_events[1]['final'] is True # Tests for FastA2A streaming endpoint + class TestFastA2AStreaming: """Tests for the FastA2A message/stream SSE endpoint.""" @@ -356,7 +415,7 @@ async def test_agent_card_shows_streaming_enabled(self): async def test_message_stream_returns_sse(self): """Test that message/stream returns SSE response.""" - async with create_streaming_app(response_text="Test response") as app: + async with create_streaming_app(response_text='Test response') as app: async with create_test_client(app) as client: payload = { 'jsonrpc': '2.0', @@ -369,7 +428,7 @@ async def test_message_stream_returns_sse(self): 'kind': 'message', 'messageId': 'user-msg-1', } - } + }, } # Use streaming request @@ -377,7 +436,7 @@ async def test_message_stream_returns_sse(self): assert response.status_code == 200 assert 'text/event-stream' in response.headers.get('content-type', '') - events = [] + events: list[Any] = [] async for line in response.aiter_lines(): if line.startswith('data: '): data = json.loads(line[6:]) @@ -391,11 +450,18 @@ async def test_message_stream_returns_sse(self): assert events[0]['result']['status']['state'] == 'submitted' # Should have a working status - working_events = [e for e in events if e['result'].get('kind') == 'status-update' and e['result'].get('status', {}).get('state') == 'working'] + working_events = [ + e + for e in events + if e['result'].get('kind') == 'status-update' + and e['result'].get('status', {}).get('state') == 'working' + ] assert len(working_events) >= 1 # Should have a message with the response - message_events = [e for e in events if e['result'].get('kind') == 'message' and e['result'].get('role') == 'agent'] + message_events = [ + e for e in events if e['result'].get('kind') == 'message' and e['result'].get('role') == 'agent' + ] assert len(message_events) >= 1 assert 'Test response' in message_events[0]['result']['parts'][0]['text'] @@ -420,11 +486,11 @@ async def test_message_stream_includes_context_id(self): 'messageId': 'msg-1', 'contextId': context_id, } - } + }, } async with client.stream('POST', '/', json=payload) as response: - events = [] + events: list[Any] = [] async for line in response.aiter_lines(): if line.startswith('data: '): events.append(json.loads(line[6:])) @@ -450,7 +516,7 @@ async def test_message_send_still_works(self): 'kind': 'message', 'messageId': 'msg-1', } - } + }, } response = await client.post('/', json=payload) @@ -462,6 +528,7 @@ async def test_message_send_still_works(self): # Tests for TaskManager stream_message + class TestTaskManagerStreamMessage: """Tests for TaskManager.stream_message method.""" @@ -470,51 +537,55 @@ async def test_stream_message_yields_events_in_order(self): broker = InMemoryBroker() base_storage = InMemoryStorage() streaming_storage = StreamingStorageWrapper(base_storage, broker) - + # Use a longer delay to ensure we capture the initial task before it completes worker = EchoWorker(broker=broker, storage=streaming_storage, delay=0.2) async with broker: async with worker.run(): from fasta2a.task_manager import TaskManager - + task_manager = TaskManager(broker=broker, storage=streaming_storage) async with task_manager: - request = { + request: StreamMessageRequest = { 'jsonrpc': '2.0', 'id': 'req-1', 'method': 'message/stream', - 'params': { - 'message': { - 'role': 'user', - 'parts': [{'kind': 'text', 'text': 'Test'}], - 'kind': 'message', - 'messageId': 'msg-1', - } - } + 'params': MessageSendParams( + message=Message( + role='user', + parts=[TextPart(kind='text', text='Test')], + kind='message', + message_id='msg-1', + ) + ), } - events = [] + events: list[StreamEvent] = [] async for event in task_manager.stream_message(request): events.append(event) # Stop when we get the final status - if event.get('kind') == 'status-update' and event.get('final'): + if isinstance(event, dict) and event.get('kind') == 'status-update' and event.get('final'): break # Should have at least: task, working status, message, final status assert len(events) >= 4 - + # First event should be the task (submitted state) assert events[0]['kind'] == 'task' - + # Should have a working status update - working_events = [e for e in events if e.get('kind') == 'status-update' and e.get('status', {}).get('state') == 'working'] + working_events = [ + e + for e in events + if e.get('kind') == 'status-update' and e.get('status', {}).get('state') == 'working' + ] assert len(working_events) >= 1 - + # Should have an agent message message_events = [e for e in events if e.get('kind') == 'message' and e.get('role') == 'agent'] assert len(message_events) >= 1 - + # Last event should be final status assert events[-1]['kind'] == 'status-update' assert events[-1]['final'] is True From d16abca0f602c1b9b002e3ae712395e3a15226d2 Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Sat, 7 Mar 2026 18:06:30 +0100 Subject: [PATCH 04/12] simplify serialization --- .gitignore | 1 + fasta2a/applications.py | 20 +++--- uv.lock | 140 ++++++++++++++++++++++++++++++++++++++-- 3 files changed, 144 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index b3c9e26..dc20eb9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .cache __pycache__ +.coverage* \ No newline at end of file diff --git a/fasta2a/applications.py b/fasta2a/applications.py index 896384c..9317d79 100644 --- a/fasta2a/applications.py +++ b/fasta2a/applications.py @@ -1,6 +1,5 @@ from __future__ import annotations as _annotations -import json from collections.abc import AsyncIterator, Sequence from contextlib import asynccontextmanager from pathlib import Path @@ -20,11 +19,12 @@ AgentCard, AgentProvider, Skill, + StreamMessageResponse, a2a_request_ta, a2a_response_ta, agent_card_ta, - stream_event_ta, stream_message_request_ta, + stream_message_response_ta, ) from .storage import Storage from .task_manager import TaskManager @@ -145,14 +145,14 @@ async def _agent_run_endpoint(self, request: Request) -> Response: async def sse_generator(): request_id = stream_request.get('id') async for event in self.task_manager.stream_message(stream_request): - # Serialize event to ensure proper camelCase conversion - event_dict = stream_event_ta.dump_python(event, mode='json', by_alias=True) - - # Wrap in JSON-RPC response - jsonrpc_response = {'jsonrpc': '2.0', 'id': request_id, 'result': event_dict} - - # Convert to JSON string - yield json.dumps(jsonrpc_response) + jsonrpc_response = StreamMessageResponse( + jsonrpc='2.0', + id=request_id, + result=event, + ) + yield stream_message_response_ta.dump_json( + jsonrpc_response, by_alias=True + ).decode() # Return SSE response return EventSourceResponse(sse_generator()) diff --git a/uv.lock b/uv.lock index a77fc88..5015736 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.9" resolution-markers = [ "python_full_version >= '3.10'", @@ -60,6 +60,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537, upload-time = "2025-02-01T15:17:37.39Z" }, ] +[[package]] +name = "backports-asyncio-runner" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893, upload-time = "2025-07-02T02:27:15.685Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" }, +] + [[package]] name = "backrefs" version = "5.9" @@ -439,7 +448,10 @@ dependencies = [ { name = "eval-type-backport", marker = "python_full_version < '3.10'" }, { name = "opentelemetry-api" }, { name = "pydantic" }, - { name = "starlette" }, + { name = "sse-starlette", version = "3.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "sse-starlette", version = "3.3.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "starlette", version = "0.49.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "starlette", version = "0.52.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, ] [package.optional-dependencies] @@ -452,9 +464,13 @@ dev = [ { name = "asgi-lifespan" }, { name = "coverage" }, { name = "httpx" }, + { name = "httpx-sse" }, { name = "inline-snapshot" }, { name = "pyright" }, { name = "pytest" }, + { name = "pytest-asyncio", version = "1.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "pytest-asyncio", version = "1.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "pytest-mock" }, { name = "ruff" }, ] docs = [ @@ -470,6 +486,7 @@ requires-dist = [ { name = "logfire", marker = "extra == 'logfire'", specifier = ">=2.3" }, { name = "opentelemetry-api", specifier = ">=1.28.0" }, { name = "pydantic", specifier = ">=2.10" }, + { name = "sse-starlette", specifier = ">=2.0.0" }, { name = "starlette", specifier = ">0.29.0" }, ] provides-extras = ["logfire"] @@ -479,9 +496,12 @@ dev = [ { name = "asgi-lifespan" }, { name = "coverage" }, { name = "httpx" }, + { name = "httpx-sse" }, { name = "inline-snapshot" }, { name = "pyright" }, { name = "pytest" }, + { name = "pytest-asyncio" }, + { name = "pytest-mock" }, { name = "ruff" }, ] docs = [ @@ -564,6 +584,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, ] +[[package]] +name = "httpx-sse" +version = "0.4.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/4c/751061ffa58615a32c31b2d82e8482be8dd4a89154f003147acee90f2be9/httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d", size = 15943, upload-time = "2025-10-10T21:48:22.271Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/fd/6668e5aec43ab844de6fc74927e155a3b37bf40d7c3790e49fc0406b6578/httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc", size = 8960, upload-time = "2025-10-10T21:48:21.158Z" }, +] + [[package]] name = "idna" version = "3.10" @@ -1356,6 +1385,52 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.10'", +] +dependencies = [ + { name = "backports-asyncio-runner", marker = "python_full_version < '3.10'" }, + { name = "pytest", marker = "python_full_version < '3.10'" }, + { name = "typing-extensions", marker = "python_full_version < '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/86/9e3c5f48f7b7b638b216e4b9e645f54d199d7abbbab7a64a13b4e12ba10f/pytest_asyncio-1.2.0.tar.gz", hash = "sha256:c609a64a2a8768462d0c99811ddb8bd2583c33fd33cf7f21af1c142e824ffb57", size = 50119, upload-time = "2025-09-12T07:33:53.816Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/93/2fa34714b7a4ae72f2f8dad66ba17dd9a2c793220719e736dda28b7aec27/pytest_asyncio-1.2.0-py3-none-any.whl", hash = "sha256:8e17ae5e46d8e7efe51ab6494dd2010f4ca8dae51652aa3c8d55acf50bfb2e99", size = 15095, upload-time = "2025-09-12T07:33:52.639Z" }, +] + +[[package]] +name = "pytest-asyncio" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.10'", +] +dependencies = [ + { name = "backports-asyncio-runner", marker = "python_full_version == '3.10.*'" }, + { name = "pytest", marker = "python_full_version >= '3.10'" }, + { name = "typing-extensions", marker = "python_full_version >= '3.10' and python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, +] + +[[package]] +name = "pytest-mock" +version = "3.15.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/68/14/eb014d26be205d38ad5ad20d9a80f7d201472e08167f0bb4361e251084a9/pytest_mock-3.15.1.tar.gz", hash = "sha256:1849a238f6f396da19762269de72cb1814ab44416fa73a8686deac10b0d87a0f", size = 34036, upload-time = "2025-09-16T16:37:27.081Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -1514,17 +1589,68 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/9c/0e6afc12c269578be5c0c1c9f4b49a8d32770a080260c333ac04cc1c832d/soupsieve-2.7-py3-none-any.whl", hash = "sha256:6e60cc5c1ffaf1cebcc12e8188320b72071e922c2e897f737cadce79ad5d30c4", size = 36677, upload-time = "2025-04-20T18:50:07.196Z" }, ] +[[package]] +name = "sse-starlette" +version = "3.3.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.10'", +] +dependencies = [ + { name = "anyio", marker = "python_full_version < '3.10'" }, + { name = "starlette", version = "0.49.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/5a/a1b3e05da1de8dd17fcbe2d6e66725bff6a8d96293db53cd0d89eac84f81/sse_starlette-3.3.0.tar.gz", hash = "sha256:fdf4a84e2230b12daa3a5a4a1a651586debdefd6eb2fbf812554490d01326896", size = 32679, upload-time = "2026-02-28T08:30:33.178Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b6/7a/e3ff499eaee52cf4c78bbf483a3ec536b0b04290e482d5a740cd2729fdcb/sse_starlette-3.3.0-py3-none-any.whl", hash = "sha256:eb5acdac069c7c8b2ce2d3c447b58016da9737ff8a8f475438d88397d49883ef", size = 14276, upload-time = "2026-02-28T08:30:31.771Z" }, +] + +[[package]] +name = "sse-starlette" +version = "3.3.2" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.10'", +] +dependencies = [ + { name = "anyio", marker = "python_full_version >= '3.10'" }, + { name = "starlette", version = "0.52.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5a/9f/c3695c2d2d4ef70072c3a06992850498b01c6bc9be531950813716b426fa/sse_starlette-3.3.2.tar.gz", hash = "sha256:678fca55a1945c734d8472a6cad186a55ab02840b4f6786f5ee8770970579dcd", size = 32326, upload-time = "2026-02-28T11:24:34.36Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/28/8cb142d3fe80c4a2d8af54ca0b003f47ce0ba920974e7990fa6e016402d1/sse_starlette-3.3.2-py3-none-any.whl", hash = "sha256:5c3ea3dad425c601236726af2f27689b74494643f57017cafcb6f8c9acfbb862", size = 14270, upload-time = "2026-02-28T11:24:32.984Z" }, +] + [[package]] name = "starlette" -version = "0.47.1" +version = "0.49.3" source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.10'", +] dependencies = [ - { name = "anyio" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "anyio", marker = "python_full_version < '3.10'" }, + { name = "typing-extensions", marker = "python_full_version < '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/de/1a/608df0b10b53b0beb96a37854ee05864d182ddd4b1156a22f1ad3860425a/starlette-0.49.3.tar.gz", hash = "sha256:1c14546f299b5901a1ea0e34410575bc33bbd741377a10484a54445588d00284", size = 2655031, upload-time = "2025-11-01T15:12:26.13Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a3/e0/021c772d6a662f43b63044ab481dc6ac7592447605b5b35a957785363122/starlette-0.49.3-py3-none-any.whl", hash = "sha256:b579b99715fdc2980cf88c8ec96d3bf1ce16f5a8051a7c2b84ef9b1cdecaea2f", size = 74340, upload-time = "2025-11-01T15:12:24.387Z" }, +] + +[[package]] +name = "starlette" +version = "0.52.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.10'", +] +dependencies = [ + { name = "anyio", marker = "python_full_version >= '3.10'" }, + { name = "typing-extensions", marker = "python_full_version >= '3.10' and python_full_version < '3.13'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0a/69/662169fdb92fb96ec3eaee218cf540a629d629c86d7993d9651226a6789b/starlette-0.47.1.tar.gz", hash = "sha256:aef012dd2b6be325ffa16698f9dc533614fb1cebd593a906b90dc1025529a79b", size = 2583072, upload-time = "2025-06-21T04:03:17.337Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c4/68/79977123bb7be889ad680d79a40f339082c1978b5cfcf62c2d8d196873ac/starlette-0.52.1.tar.gz", hash = "sha256:834edd1b0a23167694292e94f597773bc3f89f362be6effee198165a35d62933", size = 2653702, upload-time = "2026-01-18T13:34:11.062Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/82/95/38ef0cd7fa11eaba6a99b3c4f5ac948d8bc6ff199aabd327a29cc000840c/starlette-0.47.1-py3-none-any.whl", hash = "sha256:5e11c9f5c7c3f24959edbf2dffdc01bba860228acf657129467d8a7468591527", size = 72747, upload-time = "2025-06-21T04:03:15.705Z" }, + { url = "https://files.pythonhosted.org/packages/81/0d/13d1d239a25cbfb19e740db83143e95c772a1fe10202dda4b76792b114dd/starlette-0.52.1-py3-none-any.whl", hash = "sha256:0029d43eb3d273bc4f83a08720b4912ea4b071087a3b48db01b7c839f7954d74", size = 74272, upload-time = "2026-01-18T13:34:09.188Z" }, ] [[package]] From 23907cfd2f21111202b72bf37a09a4c5b3fa169e Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Sat, 7 Mar 2026 18:12:26 +0100 Subject: [PATCH 05/12] set streaming true by default --- fasta2a/applications.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/fasta2a/applications.py b/fasta2a/applications.py index 9317d79..efa5582 100644 --- a/fasta2a/applications.py +++ b/fasta2a/applications.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import json from collections.abc import AsyncIterator, Sequence from contextlib import asynccontextmanager from pathlib import Path @@ -19,12 +20,11 @@ AgentCard, AgentProvider, Skill, - StreamMessageResponse, a2a_request_ta, a2a_response_ta, agent_card_ta, + stream_event_ta, stream_message_request_ta, - stream_message_response_ta, ) from .storage import Storage from .task_manager import TaskManager @@ -46,7 +46,7 @@ def __init__( provider: AgentProvider | None = None, skills: list[Skill] | None = None, docs_url: str | None = '/docs', - streaming: bool = False, + streaming: bool = True, # Starlette debug: bool = False, routes: Sequence[Route] | None = None, @@ -145,14 +145,14 @@ async def _agent_run_endpoint(self, request: Request) -> Response: async def sse_generator(): request_id = stream_request.get('id') async for event in self.task_manager.stream_message(stream_request): - jsonrpc_response = StreamMessageResponse( - jsonrpc='2.0', - id=request_id, - result=event, - ) - yield stream_message_response_ta.dump_json( - jsonrpc_response, by_alias=True - ).decode() + # Serialize event to ensure proper camelCase conversion + event_dict = stream_event_ta.dump_python(event, mode='json', by_alias=True) + + # Wrap in JSON-RPC response + jsonrpc_response = {'jsonrpc': '2.0', 'id': request_id, 'result': event_dict} + + # Convert to JSON string + yield json.dumps(jsonrpc_response) # Return SSE response return EventSourceResponse(sse_generator()) From dec55815677ed71abb64674b0e0ec9b7360f5065 Mon Sep 17 00:00:00 2001 From: Eric Charles <226720+echarles@users.noreply.github.com> Date: Sat, 7 Mar 2026 18:13:41 +0100 Subject: [PATCH 06/12] Update fasta2a/applications.py Co-authored-by: Marcelo Trylesinski --- fasta2a/applications.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fasta2a/applications.py b/fasta2a/applications.py index efa5582..590d435 100644 --- a/fasta2a/applications.py +++ b/fasta2a/applications.py @@ -154,7 +154,6 @@ async def sse_generator(): # Convert to JSON string yield json.dumps(jsonrpc_response) - # Return SSE response return EventSourceResponse(sse_generator()) elif a2a_request['method'] == 'tasks/get': jsonrpc_response = await self.task_manager.get_task(a2a_request) From 33ee9c510565adf342e1d81236c027452a1b29bf Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Sat, 7 Mar 2026 18:14:50 +0100 Subject: [PATCH 07/12] remove comments --- fasta2a/applications.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/fasta2a/applications.py b/fasta2a/applications.py index 590d435..ba934bc 100644 --- a/fasta2a/applications.py +++ b/fasta2a/applications.py @@ -138,20 +138,15 @@ async def _agent_run_endpoint(self, request: Request) -> Response: if a2a_request['method'] == 'message/send': jsonrpc_response = await self.task_manager.send_message(a2a_request) elif a2a_request['method'] == 'message/stream': - # Parse the streaming request stream_request = stream_message_request_ta.validate_json(data) - # Create an async generator wrapper that formats events as JSON-RPC responses async def sse_generator(): request_id = stream_request.get('id') async for event in self.task_manager.stream_message(stream_request): - # Serialize event to ensure proper camelCase conversion event_dict = stream_event_ta.dump_python(event, mode='json', by_alias=True) - # Wrap in JSON-RPC response jsonrpc_response = {'jsonrpc': '2.0', 'id': request_id, 'result': event_dict} - # Convert to JSON string yield json.dumps(jsonrpc_response) return EventSourceResponse(sse_generator()) From 17096739a73d1cc1fd1471fc26d5ca619e33ada5 Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Sat, 7 Mar 2026 18:17:52 +0100 Subject: [PATCH 08/12] revert NotImplementedError --- fasta2a/broker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fasta2a/broker.py b/fasta2a/broker.py index 7ca631f..650c9e2 100644 --- a/fasta2a/broker.py +++ b/fasta2a/broker.py @@ -31,12 +31,12 @@ class Broker(ABC): @abstractmethod async def run_task(self, params: TaskSendParams) -> None: """Send a task to be executed by the worker.""" - ... + raise NotImplementedError('send_run_task is not implemented yet.') @abstractmethod async def cancel_task(self, params: TaskIdParams) -> None: """Cancel a task.""" - ... + raise NotImplementedError('send_cancel_task is not implemented yet.') @abstractmethod async def send_stream_event(self, task_id: str, event: StreamEvent) -> None: From 2db81b8f7fc8c378bbbbe5773d359c991f9ae21b Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Sat, 7 Mar 2026 18:23:50 +0100 Subject: [PATCH 09/12] use storage --- fasta2a/__init__.py | 4 +- fasta2a/applications.py | 17 ++++---- fasta2a/storage.py | 86 +---------------------------------------- fasta2a/worker.py | 63 +++++++++++++++++++++++++++++- tests/test_streaming.py | 52 ++++++++++++------------- 5 files changed, 100 insertions(+), 122 deletions(-) diff --git a/fasta2a/__init__.py b/fasta2a/__init__.py index 57a02f5..59e5b07 100644 --- a/fasta2a/__init__.py +++ b/fasta2a/__init__.py @@ -1,7 +1,7 @@ from .applications import FastA2A from .broker import Broker from .schema import Skill, StreamEvent -from .storage import Storage, StreamingStorageWrapper +from .storage import Storage from .worker import Worker -__all__ = ['FastA2A', 'Skill', 'Storage', 'StreamingStorageWrapper', 'Broker', 'Worker', 'StreamEvent'] +__all__ = ['FastA2A', 'Skill', 'Storage', 'Broker', 'Worker', 'StreamEvent'] diff --git a/fasta2a/applications.py b/fasta2a/applications.py index ba934bc..0dabb51 100644 --- a/fasta2a/applications.py +++ b/fasta2a/applications.py @@ -1,6 +1,5 @@ from __future__ import annotations as _annotations -import json from collections.abc import AsyncIterator, Sequence from contextlib import asynccontextmanager from pathlib import Path @@ -20,11 +19,12 @@ AgentCard, AgentProvider, Skill, + StreamMessageResponse, a2a_request_ta, a2a_response_ta, agent_card_ta, - stream_event_ta, stream_message_request_ta, + stream_message_response_ta, ) from .storage import Storage from .task_manager import TaskManager @@ -143,11 +143,14 @@ async def _agent_run_endpoint(self, request: Request) -> Response: async def sse_generator(): request_id = stream_request.get('id') async for event in self.task_manager.stream_message(stream_request): - event_dict = stream_event_ta.dump_python(event, mode='json', by_alias=True) - - jsonrpc_response = {'jsonrpc': '2.0', 'id': request_id, 'result': event_dict} - - yield json.dumps(jsonrpc_response) + jsonrpc_response = StreamMessageResponse( + jsonrpc='2.0', + id=request_id, + result=event, + ) + yield stream_message_response_ta.dump_json( + jsonrpc_response, by_alias=True + ).decode() return EventSourceResponse(sse_generator()) elif a2a_request['method'] == 'tasks/get': diff --git a/fasta2a/storage.py b/fasta2a/storage.py index 26ee2f5..cc94d47 100644 --- a/fasta2a/storage.py +++ b/fasta2a/storage.py @@ -5,7 +5,7 @@ import uuid from abc import ABC, abstractmethod from datetime import datetime -from typing import TYPE_CHECKING, Any, Generic +from typing import Any, Generic from typing_extensions import TypeVar @@ -13,15 +13,10 @@ Artifact, Message, Task, - TaskArtifactUpdateEvent, TaskState, TaskStatus, - TaskStatusUpdateEvent, ) -if TYPE_CHECKING: - from .broker import Broker - ContextT = TypeVar('ContextT', default=Any) @@ -140,82 +135,3 @@ async def update_context(self, context_id: str, context: ContextT) -> None: async def load_context(self, context_id: str) -> ContextT | None: """Retrieve the stored context given the `context_id`.""" return self.contexts.get(context_id) - - -class StreamingStorageWrapper(Storage[ContextT]): - """A storage wrapper that publishes streaming events when tasks are updated. - - This wrapper intercepts update_task calls and publishes TaskStatusUpdateEvent - and TaskArtifactUpdateEvent to the broker, enabling SSE streaming without - modifying the underlying storage or worker implementations. - """ - - def __init__(self, storage: Storage[ContextT], broker: Broker): - self._storage = storage - self._broker = broker - - async def load_task(self, task_id: str, history_length: int | None = None) -> Task | None: - return await self._storage.load_task(task_id, history_length) - - async def submit_task(self, context_id: str, message: Message) -> Task: - return await self._storage.submit_task(context_id, message) - - async def update_task( - self, - task_id: str, - state: TaskState, - new_artifacts: list[Artifact] | None = None, - new_messages: list[Message] | None = None, - ) -> Task: - """Update task and publish streaming events.""" - # Update the underlying storage first - task = await self._storage.update_task(task_id, state, new_artifacts, new_messages) - - # Determine if this is a final state - final = state in ('completed', 'failed', 'canceled') - - # For non-final updates, publish status first - if not final: - status_event = TaskStatusUpdateEvent( - kind='status-update', - task_id=task_id, - context_id=task['context_id'], - status=task['status'], - final=False, - ) - await self._broker.send_stream_event(task_id, status_event) - - # Publish message events BEFORE final status (so subscriber receives them) - if new_messages: - for message in new_messages: - await self._broker.send_stream_event(task_id, message) - - # Publish artifact events - if new_artifacts: - for artifact in new_artifacts: - artifact_event = TaskArtifactUpdateEvent( - kind='artifact-update', - task_id=task_id, - context_id=task['context_id'], - artifact=artifact, - ) - await self._broker.send_stream_event(task_id, artifact_event) - - # For final updates, publish status LAST (after messages and artifacts) - if final: - status_event = TaskStatusUpdateEvent( - kind='status-update', - task_id=task_id, - context_id=task['context_id'], - status=task['status'], - final=True, - ) - await self._broker.send_stream_event(task_id, status_event) - - return task - - async def load_context(self, context_id: str) -> ContextT | None: - return await self._storage.load_context(context_id) - - async def update_context(self, context_id: str, context: ContextT) -> None: - await self._storage.update_context(context_id, context) diff --git a/fasta2a/worker.py b/fasta2a/worker.py index bcb0172..3499c45 100644 --- a/fasta2a/worker.py +++ b/fasta2a/worker.py @@ -10,11 +10,12 @@ from opentelemetry.trace import get_tracer, use_span from typing_extensions import assert_never +from .schema import TaskArtifactUpdateEvent, TaskStatusUpdateEvent from .storage import ContextT, Storage if TYPE_CHECKING: from .broker import Broker, TaskOperation - from .schema import Artifact, Message, TaskIdParams, TaskSendParams + from .schema import Artifact, Message, TaskIdParams, TaskSendParams, TaskState tracer = get_tracer(__name__) @@ -56,6 +57,66 @@ async def _handle_task_operation(self, task_operation: TaskOperation) -> None: except Exception: await self.storage.update_task(task_operation['params']['id'], state='failed') + async def update_task( + self, + task_id: str, + state: TaskState, + new_artifacts: list[Artifact] | None = None, + new_messages: list[Message] | None = None, + ) -> None: + """Update a task's state in storage and publish streaming events to the broker. + + This is the primary method workers should use to update task state. It handles + both persisting the update and notifying any stream subscribers. + """ + task = await self.storage.update_task(task_id, state, new_artifacts, new_messages) + + final = state in ('completed', 'failed', 'canceled') + + # For non-final updates, publish status first + if not final: + await self.broker.send_stream_event( + task_id, + TaskStatusUpdateEvent( + kind='status-update', + task_id=task_id, + context_id=task['context_id'], + status=task['status'], + final=False, + ), + ) + + # Publish message events before final status so subscribers receive them + if new_messages: + for message in new_messages: + await self.broker.send_stream_event(task_id, message) + + # Publish artifact events + if new_artifacts: + for artifact in new_artifacts: + await self.broker.send_stream_event( + task_id, + TaskArtifactUpdateEvent( + kind='artifact-update', + task_id=task_id, + context_id=task['context_id'], + artifact=artifact, + ), + ) + + # For final updates, publish status last (after messages and artifacts) + if final: + await self.broker.send_stream_event( + task_id, + TaskStatusUpdateEvent( + kind='status-update', + task_id=task_id, + context_id=task['context_id'], + status=task['status'], + final=True, + ), + ) + @abstractmethod async def run_task(self, params: TaskSendParams) -> None: ... diff --git a/tests/test_streaming.py b/tests/test_streaming.py index d34a95f..d548d9d 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -1,7 +1,7 @@ """Tests for the SSE streaming feature in fasta2a. This module tests: -- StreamingStorageWrapper event publishing +- Worker.update_task() streaming event publishing - InMemoryBroker pub/sub for streaming - TaskManager stream_message method - FastA2A message/stream endpoint @@ -34,7 +34,7 @@ TaskStatusUpdateEvent, TextPart, ) -from fasta2a.storage import InMemoryStorage, StreamingStorageWrapper +from fasta2a.storage import InMemoryStorage pytestmark = pytest.mark.anyio @@ -66,7 +66,7 @@ async def run_task(self, params: TaskSendParams) -> None: assert task is not None # Update to working state - await self.storage.update_task(task['id'], state='working') + await self.update_task(task['id'], state='working') # Simulate some work await asyncio.sleep(self.delay) @@ -87,7 +87,7 @@ async def run_task(self, params: TaskSendParams) -> None: ) # Complete the task with message and artifact - await self.storage.update_task( + await self.update_task( task['id'], state='completed', new_messages=[message], @@ -95,7 +95,7 @@ async def run_task(self, params: TaskSendParams) -> None: ) async def cancel_task(self, params: TaskIdParams) -> None: - await self.storage.update_task(params['id'], state='canceled') + await self.update_task(params['id'], state='canceled') def build_message_history(self, history: list[Message]) -> list[Any]: return history @@ -108,12 +108,11 @@ def build_artifacts(self, result: Any) -> list[Artifact]: async def create_streaming_app(response_text: str = 'Hello!'): """Create a FastA2A app with streaming enabled.""" broker = InMemoryBroker() - base_storage = InMemoryStorage() - streaming_storage = StreamingStorageWrapper(base_storage, broker) + storage = InMemoryStorage() worker = EchoWorker( broker=broker, - storage=streaming_storage, + storage=storage, response_text=response_text, ) @@ -124,7 +123,7 @@ async def lifespan(app: FastA2A): yield app = FastA2A( - storage=streaming_storage, + storage=storage, broker=broker, streaming=True, lifespan=lifespan, @@ -264,17 +263,17 @@ async def test_no_subscribers_doesnt_error(self): ) -# Tests for StreamingStorageWrapper +# Tests for Worker.update_task() streaming events -class TestStreamingStorageWrapper: - """Tests for StreamingStorageWrapper event publishing.""" +class TestWorkerUpdateTask: + """Tests for Worker.update_task() event publishing.""" async def test_publishes_status_update_on_working(self): """Test that updating to 'working' publishes a status-update event.""" broker = InMemoryBroker() - base_storage = InMemoryStorage() - storage = StreamingStorageWrapper(base_storage, broker) + storage = InMemoryStorage() + worker = EchoWorker(broker=broker, storage=storage) async with broker: # Create a task @@ -298,8 +297,8 @@ async def subscriber(): sub_task = asyncio.create_task(subscriber()) await asyncio.sleep(0.01) - # Update to working - await storage.update_task(task_id, state='working') + # Update to working via worker + await worker.update_task(task_id, state='working') await asyncio.wait_for(sub_task, timeout=1.0) @@ -311,8 +310,8 @@ async def subscriber(): async def test_publishes_message_before_final_status(self): """Test that messages are published before the final status update.""" broker = InMemoryBroker() - base_storage = InMemoryStorage() - storage = StreamingStorageWrapper(base_storage, broker) + storage = InMemoryStorage() + worker = EchoWorker(broker=broker, storage=storage) async with broker: message = Message( @@ -335,14 +334,14 @@ async def subscriber(): sub_task = asyncio.create_task(subscriber()) await asyncio.sleep(0.01) - # Complete with a new message + # Complete with a new message via worker agent_message = Message( role='agent', parts=[TextPart(text='Hi there!', kind='text')], kind='message', message_id='msg-2', ) - await storage.update_task(task_id, state='completed', new_messages=[agent_message]) + await worker.update_task(task_id, state='completed', new_messages=[agent_message]) await asyncio.wait_for(sub_task, timeout=1.0) @@ -356,8 +355,8 @@ async def subscriber(): async def test_publishes_artifacts(self): """Test that artifacts are published.""" broker = InMemoryBroker() - base_storage = InMemoryStorage() - storage = StreamingStorageWrapper(base_storage, broker) + storage = InMemoryStorage() + worker = EchoWorker(broker=broker, storage=storage) async with broker: message = Message( @@ -385,7 +384,7 @@ async def subscriber(): name='result', parts=[TextPart(text='Result data', kind='text')], ) - await storage.update_task(task_id, state='completed', new_artifacts=[artifact]) + await worker.update_task(task_id, state='completed', new_artifacts=[artifact]) await asyncio.wait_for(sub_task, timeout=1.0) @@ -535,17 +534,16 @@ class TestTaskManagerStreamMessage: async def test_stream_message_yields_events_in_order(self): """Test that stream_message yields events: task, status updates, messages, final status.""" broker = InMemoryBroker() - base_storage = InMemoryStorage() - streaming_storage = StreamingStorageWrapper(base_storage, broker) + storage = InMemoryStorage() # Use a longer delay to ensure we capture the initial task before it completes - worker = EchoWorker(broker=broker, storage=streaming_storage, delay=0.2) + worker = EchoWorker(broker=broker, storage=storage, delay=0.2) async with broker: async with worker.run(): from fasta2a.task_manager import TaskManager - task_manager = TaskManager(broker=broker, storage=streaming_storage) + task_manager = TaskManager(broker=broker, storage=storage) async with task_manager: request: StreamMessageRequest = { 'jsonrpc': '2.0', From ab97e621fa8e8f41860e775327de7af5c6708af1 Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Sat, 7 Mar 2026 18:39:34 +0100 Subject: [PATCH 10/12] no test class --- tests/test_streaming.py | 811 ++++++++++++++++++++-------------------- 1 file changed, 403 insertions(+), 408 deletions(-) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index d548d9d..baf5eea 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -135,456 +135,451 @@ async def lifespan(app: FastA2A): # Tests for InMemoryBroker streaming -class TestInMemoryBrokerStreaming: - """Tests for InMemoryBroker pub/sub streaming.""" - - async def test_subscribe_and_receive_events(self): - """Test that subscribers receive events sent by send_stream_event.""" - broker = InMemoryBroker() - - async with broker: - task_id = 'test-task-123' - received_events: list[StreamEvent] = [] - - # Start subscriber in background - async def subscriber(): - async for event in broker.subscribe_to_stream(task_id): - received_events.append(event) - # Check if event has a 'final' field (for status events) - if isinstance(event, dict) and event.get('final', False): - break +async def test_broker_subscribe_and_receive_events(): + """Test that subscribers receive events sent by send_stream_event.""" + broker = InMemoryBroker() + + async with broker: + task_id = 'test-task-123' + received_events: list[StreamEvent] = [] + + # Start subscriber in background + async def subscriber(): + async for event in broker.subscribe_to_stream(task_id): + received_events.append(event) + # Check if event has a 'final' field (for status events) + if isinstance(event, dict) and event.get('final', False): + break + + subscriber_task = asyncio.create_task(subscriber()) + + # Give subscriber time to register + await asyncio.sleep(0.01) + + # Send events + status_event_1 = TaskStatusUpdateEvent( + task_id=task_id, + context_id='test-context', + kind='status-update', + status=TaskStatus(state='working'), + final=False, + ) + message_event = Message( + message_id='test-msg-1', + role='agent', + parts=[TextPart(kind='text', text='Hello')], + kind='message', + ) + status_event_2 = TaskStatusUpdateEvent( + task_id=task_id, + context_id='test-context', + kind='status-update', + status=TaskStatus(state='completed'), + final=True, + ) + await broker.send_stream_event(task_id, status_event_1) + await broker.send_stream_event(task_id, message_event) + await broker.send_stream_event(task_id, status_event_2) - subscriber_task = asyncio.create_task(subscriber()) + # Wait for subscriber to finish + await asyncio.wait_for(subscriber_task, timeout=1.0) - # Give subscriber time to register - await asyncio.sleep(0.01) + assert len(received_events) == 3 + assert received_events[0]['kind'] == 'status-update' + assert received_events[1]['kind'] == 'message' # Message has both kind and role + # Only check 'final' on status update events + assert received_events[2]['kind'] == 'status-update' and received_events[2]['final'] is True - # Send events - status_event_1 = TaskStatusUpdateEvent( + +async def test_broker_multiple_subscribers(): + """Test that multiple subscribers receive the same events.""" + broker = InMemoryBroker() + + async with broker: + task_id = 'test-task-456' + received_1: list[StreamEvent] = [] + received_2: list[StreamEvent] = [] + + async def subscriber1(): + async for event in broker.subscribe_to_stream(task_id): + received_1.append(event) + if isinstance(event, dict) and event.get('final', False): + break + + async def subscriber2(): + async for event in broker.subscribe_to_stream(task_id): + received_2.append(event) + if isinstance(event, dict) and event.get('final', False): + break + + task1 = asyncio.create_task(subscriber1()) + task2 = asyncio.create_task(subscriber2()) + + await asyncio.sleep(0.01) + + await broker.send_stream_event( + task_id, + TaskStatusUpdateEvent( task_id=task_id, context_id='test-context', kind='status-update', status=TaskStatus(state='working'), final=False, - ) - message_event = Message( - message_id='test-msg-1', - role='agent', - parts=[TextPart(kind='text', text='Hello')], - kind='message', - ) - status_event_2 = TaskStatusUpdateEvent( + ), + ) + await broker.send_stream_event( + task_id, + TaskStatusUpdateEvent( task_id=task_id, context_id='test-context', kind='status-update', status=TaskStatus(state='completed'), final=True, - ) - await broker.send_stream_event(task_id, status_event_1) - await broker.send_stream_event(task_id, message_event) - await broker.send_stream_event(task_id, status_event_2) - - # Wait for subscriber to finish - await asyncio.wait_for(subscriber_task, timeout=1.0) - - assert len(received_events) == 3 - assert received_events[0]['kind'] == 'status-update' - assert received_events[1]['kind'] == 'message' # Message has both kind and role - # Only check 'final' on status update events - assert received_events[2]['kind'] == 'status-update' and received_events[2]['final'] is True - - async def test_multiple_subscribers(self): - """Test that multiple subscribers receive the same events.""" - broker = InMemoryBroker() - - async with broker: - task_id = 'test-task-456' - received_1: list[StreamEvent] = [] - received_2: list[StreamEvent] = [] - - async def subscriber1(): - async for event in broker.subscribe_to_stream(task_id): - received_1.append(event) - if isinstance(event, dict) and event.get('final', False): - break + ), + ) - async def subscriber2(): - async for event in broker.subscribe_to_stream(task_id): - received_2.append(event) - if isinstance(event, dict) and event.get('final', False): - break + await asyncio.wait_for(asyncio.gather(task1, task2), timeout=1.0) + + assert len(received_1) == 2 + assert len(received_2) == 2 - task1 = asyncio.create_task(subscriber1()) - task2 = asyncio.create_task(subscriber2()) - - await asyncio.sleep(0.01) - - await broker.send_stream_event( - task_id, - TaskStatusUpdateEvent( - task_id=task_id, - context_id='test-context', - kind='status-update', - status=TaskStatus(state='working'), - final=False, - ), - ) - await broker.send_stream_event( - task_id, - TaskStatusUpdateEvent( - task_id=task_id, - context_id='test-context', - kind='status-update', - status=TaskStatus(state='completed'), - final=True, - ), - ) - - await asyncio.wait_for(asyncio.gather(task1, task2), timeout=1.0) - - assert len(received_1) == 2 - assert len(received_2) == 2 - - async def test_no_subscribers_doesnt_error(self): - """Test that sending events with no subscribers doesn't raise errors.""" - broker = InMemoryBroker() - - async with broker: - # Should not raise - await broker.send_stream_event( - 'nonexistent-task', - TaskStatusUpdateEvent( - task_id='nonexistent-task', - context_id='test-context', - kind='status-update', - status=TaskStatus(state='working'), - final=False, - ), - ) + +async def test_broker_no_subscribers_doesnt_error(): + """Test that sending events with no subscribers doesn't raise errors.""" + broker = InMemoryBroker() + + async with broker: + # Should not raise + await broker.send_stream_event( + 'nonexistent-task', + TaskStatusUpdateEvent( + task_id='nonexistent-task', + context_id='test-context', + kind='status-update', + status=TaskStatus(state='working'), + final=False, + ), + ) # Tests for Worker.update_task() streaming events -class TestWorkerUpdateTask: - """Tests for Worker.update_task() event publishing.""" - - async def test_publishes_status_update_on_working(self): - """Test that updating to 'working' publishes a status-update event.""" - broker = InMemoryBroker() - storage = InMemoryStorage() - worker = EchoWorker(broker=broker, storage=storage) - - async with broker: - # Create a task - message = Message( - role='user', - parts=[TextPart(text='Hello', kind='text')], - kind='message', - message_id='msg-1', - ) - task = await storage.submit_task('ctx-1', message) - task_id = task['id'] - - received_events: list[StreamEvent] = [] - - async def subscriber(): - async for event in broker.subscribe_to_stream(task_id): - received_events.append(event) - # Stop after first event for this test +async def test_worker_publishes_status_update_on_working(): + """Test that updating to 'working' publishes a status-update event.""" + broker = InMemoryBroker() + storage = InMemoryStorage() + worker = EchoWorker(broker=broker, storage=storage) + + async with broker: + # Create a task + message = Message( + role='user', + parts=[TextPart(text='Hello', kind='text')], + kind='message', + message_id='msg-1', + ) + task = await storage.submit_task('ctx-1', message) + task_id = task['id'] + + received_events: list[StreamEvent] = [] + + async def subscriber(): + async for event in broker.subscribe_to_stream(task_id): + received_events.append(event) + # Stop after first event for this test + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.01) + + # Update to working via worker + await worker.update_task(task_id, state='working') + + await asyncio.wait_for(sub_task, timeout=1.0) + + assert len(received_events) == 1 + assert received_events[0]['kind'] == 'status-update' + assert received_events[0]['status']['state'] == 'working' + assert received_events[0]['final'] is False + + +async def test_worker_publishes_message_before_final_status(): + """Test that messages are published before the final status update.""" + broker = InMemoryBroker() + storage = InMemoryStorage() + worker = EchoWorker(broker=broker, storage=storage) + + async with broker: + message = Message( + role='user', + parts=[TextPart(text='Hello', kind='text')], + kind='message', + message_id='msg-1', + ) + task = await storage.submit_task('ctx-1', message) + task_id = task['id'] + + received_events: list[StreamEvent] = [] + + async def subscriber(): + async for event in broker.subscribe_to_stream(task_id): + received_events.append(event) + if isinstance(event, dict) and event.get('final', False): break - sub_task = asyncio.create_task(subscriber()) - await asyncio.sleep(0.01) - - # Update to working via worker - await worker.update_task(task_id, state='working') - - await asyncio.wait_for(sub_task, timeout=1.0) - - assert len(received_events) == 1 - assert received_events[0]['kind'] == 'status-update' - assert received_events[0]['status']['state'] == 'working' - assert received_events[0]['final'] is False - - async def test_publishes_message_before_final_status(self): - """Test that messages are published before the final status update.""" - broker = InMemoryBroker() - storage = InMemoryStorage() - worker = EchoWorker(broker=broker, storage=storage) - - async with broker: - message = Message( - role='user', - parts=[TextPart(text='Hello', kind='text')], - kind='message', - message_id='msg-1', - ) - task = await storage.submit_task('ctx-1', message) - task_id = task['id'] - - received_events: list[StreamEvent] = [] - - async def subscriber(): - async for event in broker.subscribe_to_stream(task_id): - received_events.append(event) - if isinstance(event, dict) and event.get('final', False): - break + sub_task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.01) - sub_task = asyncio.create_task(subscriber()) - await asyncio.sleep(0.01) - - # Complete with a new message via worker - agent_message = Message( - role='agent', - parts=[TextPart(text='Hi there!', kind='text')], - kind='message', - message_id='msg-2', - ) - await worker.update_task(task_id, state='completed', new_messages=[agent_message]) - - await asyncio.wait_for(sub_task, timeout=1.0) - - # Should have: message, then final status - assert len(received_events) == 2 - assert received_events[0]['kind'] == 'message' - assert received_events[0]['role'] == 'agent' - assert received_events[1]['kind'] == 'status-update' - assert received_events[1]['final'] is True - - async def test_publishes_artifacts(self): - """Test that artifacts are published.""" - broker = InMemoryBroker() - storage = InMemoryStorage() - worker = EchoWorker(broker=broker, storage=storage) - - async with broker: - message = Message( - role='user', - parts=[TextPart(text='Hello', kind='text')], - kind='message', - message_id='msg-1', - ) - task = await storage.submit_task('ctx-1', message) - task_id = task['id'] - - received_events: list[StreamEvent] = [] - - async def subscriber(): - async for event in broker.subscribe_to_stream(task_id): - received_events.append(event) - if isinstance(event, dict) and event.get('final', False): - break + # Complete with a new message via worker + agent_message = Message( + role='agent', + parts=[TextPart(text='Hi there!', kind='text')], + kind='message', + message_id='msg-2', + ) + await worker.update_task(task_id, state='completed', new_messages=[agent_message]) - sub_task = asyncio.create_task(subscriber()) - await asyncio.sleep(0.01) + await asyncio.wait_for(sub_task, timeout=1.0) - artifact = Artifact( - artifact_id='art-1', - name='result', - parts=[TextPart(text='Result data', kind='text')], - ) - await worker.update_task(task_id, state='completed', new_artifacts=[artifact]) + # Should have: message, then final status + assert len(received_events) == 2 + assert received_events[0]['kind'] == 'message' + assert received_events[0]['role'] == 'agent' + assert received_events[1]['kind'] == 'status-update' + assert received_events[1]['final'] is True - await asyncio.wait_for(sub_task, timeout=1.0) - # Should have: artifact, then final status - assert len(received_events) == 2 - assert received_events[0]['kind'] == 'artifact-update' - artifact = received_events[0]['artifact'] - assert artifact.get('name') == 'result' # Use .get() for NotRequired fields - assert received_events[1]['kind'] == 'status-update' - assert received_events[1]['final'] is True +async def test_worker_publishes_artifacts(): + """Test that artifacts are published.""" + broker = InMemoryBroker() + storage = InMemoryStorage() + worker = EchoWorker(broker=broker, storage=storage) + async with broker: + message = Message( + role='user', + parts=[TextPart(text='Hello', kind='text')], + kind='message', + message_id='msg-1', + ) + task = await storage.submit_task('ctx-1', message) + task_id = task['id'] -# Tests for FastA2A streaming endpoint + received_events: list[StreamEvent] = [] + async def subscriber(): + async for event in broker.subscribe_to_stream(task_id): + received_events.append(event) + if isinstance(event, dict) and event.get('final', False): + break -class TestFastA2AStreaming: - """Tests for the FastA2A message/stream SSE endpoint.""" + sub_task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.01) - async def test_agent_card_shows_streaming_enabled(self): - """Test that agent card reflects streaming capability.""" - async with create_streaming_app() as app: - async with create_test_client(app) as client: - response = await client.get('/.well-known/agent-card.json') - assert response.status_code == 200 - data = response.json() - assert data['capabilities']['streaming'] is True - - async def test_message_stream_returns_sse(self): - """Test that message/stream returns SSE response.""" - async with create_streaming_app(response_text='Test response') as app: - async with create_test_client(app) as client: - payload = { - 'jsonrpc': '2.0', - 'id': 'test-request-1', - 'method': 'message/stream', - 'params': { - 'message': { - 'role': 'user', - 'parts': [{'kind': 'text', 'text': 'Hello'}], - 'kind': 'message', - 'messageId': 'user-msg-1', - } - }, - } + artifact = Artifact( + artifact_id='art-1', + name='result', + parts=[TextPart(text='Result data', kind='text')], + ) + await worker.update_task(task_id, state='completed', new_artifacts=[artifact]) - # Use streaming request - async with client.stream('POST', '/', json=payload) as response: - assert response.status_code == 200 - assert 'text/event-stream' in response.headers.get('content-type', '') - - events: list[Any] = [] - async for line in response.aiter_lines(): - if line.startswith('data: '): - data = json.loads(line[6:]) - events.append(data) - - # Should have multiple events - assert len(events) >= 3 # task, working status, completed status - - # First event should be the task - assert events[0]['result']['kind'] == 'task' - assert events[0]['result']['status']['state'] == 'submitted' - - # Should have a working status - working_events = [ - e - for e in events - if e['result'].get('kind') == 'status-update' - and e['result'].get('status', {}).get('state') == 'working' - ] - assert len(working_events) >= 1 - - # Should have a message with the response - message_events = [ - e for e in events if e['result'].get('kind') == 'message' and e['result'].get('role') == 'agent' - ] - assert len(message_events) >= 1 - assert 'Test response' in message_events[0]['result']['parts'][0]['text'] - - # Last event should be final status - final_events = [e for e in events if e['result'].get('final') is True] - assert len(final_events) >= 1 - - async def test_message_stream_includes_context_id(self): - """Test that streamed events include context_id.""" - async with create_streaming_app() as app: - async with create_test_client(app) as client: - context_id = str(uuid.uuid4()) - payload = { - 'jsonrpc': '2.0', - 'id': 'test-2', - 'method': 'message/stream', - 'params': { - 'message': { - 'role': 'user', - 'parts': [{'kind': 'text', 'text': 'Hi'}], - 'kind': 'message', - 'messageId': 'msg-1', - 'contextId': context_id, - } - }, - } + await asyncio.wait_for(sub_task, timeout=1.0) - async with client.stream('POST', '/', json=payload) as response: - events: list[Any] = [] - async for line in response.aiter_lines(): - if line.startswith('data: '): - events.append(json.loads(line[6:])) - - # All events should have the same context_id - for event in events: - result = event['result'] - event_context = result.get('contextId') or result.get('context_id') - assert event_context == context_id - - async def test_message_send_still_works(self): - """Test that non-streaming message/send still works.""" - async with create_streaming_app() as app: - async with create_test_client(app) as client: - payload = { - 'jsonrpc': '2.0', - 'id': 'test-3', - 'method': 'message/send', - 'params': { - 'message': { - 'role': 'user', - 'parts': [{'kind': 'text', 'text': 'Hello'}], - 'kind': 'message', - 'messageId': 'msg-1', - } - }, - } + # Should have: artifact, then final status + assert len(received_events) == 2 + assert received_events[0]['kind'] == 'artifact-update' + artifact = received_events[0]['artifact'] + assert artifact.get('name') == 'result' # Use .get() for NotRequired fields + assert received_events[1]['kind'] == 'status-update' + assert received_events[1]['final'] is True + + +# Tests for FastA2A streaming endpoint - response = await client.post('/', json=payload) + +async def test_agent_card_shows_streaming_enabled(): + """Test that agent card reflects streaming capability.""" + async with create_streaming_app() as app: + async with create_test_client(app) as client: + response = await client.get('/.well-known/agent-card.json') + assert response.status_code == 200 + data = response.json() + assert data['capabilities']['streaming'] is True + + +async def test_message_stream_returns_sse(): + """Test that message/stream returns SSE response.""" + async with create_streaming_app(response_text='Test response') as app: + async with create_test_client(app) as client: + payload = { + 'jsonrpc': '2.0', + 'id': 'test-request-1', + 'method': 'message/stream', + 'params': { + 'message': { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello'}], + 'kind': 'message', + 'messageId': 'user-msg-1', + } + }, + } + + # Use streaming request + async with client.stream('POST', '/', json=payload) as response: assert response.status_code == 200 - data = response.json() - assert data['result']['kind'] == 'task' - assert data['result']['status']['state'] == 'submitted' + assert 'text/event-stream' in response.headers.get('content-type', '') + + events: list[Any] = [] + async for line in response.aiter_lines(): + if line.startswith('data: '): + data = json.loads(line[6:]) + events.append(data) + + # Should have multiple events + assert len(events) >= 3 # task, working status, completed status + + # First event should be the task + assert events[0]['result']['kind'] == 'task' + assert events[0]['result']['status']['state'] == 'submitted' + + # Should have a working status + working_events = [ + e + for e in events + if e['result'].get('kind') == 'status-update' + and e['result'].get('status', {}).get('state') == 'working' + ] + assert len(working_events) >= 1 + + # Should have a message with the response + message_events = [ + e for e in events if e['result'].get('kind') == 'message' and e['result'].get('role') == 'agent' + ] + assert len(message_events) >= 1 + assert 'Test response' in message_events[0]['result']['parts'][0]['text'] + + # Last event should be final status + final_events = [e for e in events if e['result'].get('final') is True] + assert len(final_events) >= 1 + + +async def test_message_stream_includes_context_id(): + """Test that streamed events include context_id.""" + async with create_streaming_app() as app: + async with create_test_client(app) as client: + context_id = str(uuid.uuid4()) + payload = { + 'jsonrpc': '2.0', + 'id': 'test-2', + 'method': 'message/stream', + 'params': { + 'message': { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hi'}], + 'kind': 'message', + 'messageId': 'msg-1', + 'contextId': context_id, + } + }, + } + + async with client.stream('POST', '/', json=payload) as response: + events: list[Any] = [] + async for line in response.aiter_lines(): + if line.startswith('data: '): + events.append(json.loads(line[6:])) + + # All events should have the same context_id + for event in events: + result = event['result'] + event_context = result.get('contextId') or result.get('context_id') + assert event_context == context_id + + +async def test_message_send_still_works(): + """Test that non-streaming message/send still works.""" + async with create_streaming_app() as app: + async with create_test_client(app) as client: + payload = { + 'jsonrpc': '2.0', + 'id': 'test-3', + 'method': 'message/send', + 'params': { + 'message': { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello'}], + 'kind': 'message', + 'messageId': 'msg-1', + } + }, + } + + response = await client.post('/', json=payload) + assert response.status_code == 200 + data = response.json() + assert data['result']['kind'] == 'task' + assert data['result']['status']['state'] == 'submitted' # Tests for TaskManager stream_message -class TestTaskManagerStreamMessage: - """Tests for TaskManager.stream_message method.""" +async def test_stream_message_yields_events_in_order(): + """Test that stream_message yields events: task, status updates, messages, final status.""" + broker = InMemoryBroker() + storage = InMemoryStorage() - async def test_stream_message_yields_events_in_order(self): - """Test that stream_message yields events: task, status updates, messages, final status.""" - broker = InMemoryBroker() - storage = InMemoryStorage() + # Use a longer delay to ensure we capture the initial task before it completes + worker = EchoWorker(broker=broker, storage=storage, delay=0.2) - # Use a longer delay to ensure we capture the initial task before it completes - worker = EchoWorker(broker=broker, storage=storage, delay=0.2) + async with broker: + async with worker.run(): + from fasta2a.task_manager import TaskManager - async with broker: - async with worker.run(): - from fasta2a.task_manager import TaskManager - - task_manager = TaskManager(broker=broker, storage=storage) - async with task_manager: - request: StreamMessageRequest = { - 'jsonrpc': '2.0', - 'id': 'req-1', - 'method': 'message/stream', - 'params': MessageSendParams( - message=Message( - role='user', - parts=[TextPart(kind='text', text='Test')], - kind='message', - message_id='msg-1', - ) - ), - } + task_manager = TaskManager(broker=broker, storage=storage) + async with task_manager: + request: StreamMessageRequest = { + 'jsonrpc': '2.0', + 'id': 'req-1', + 'method': 'message/stream', + 'params': MessageSendParams( + message=Message( + role='user', + parts=[TextPart(kind='text', text='Test')], + kind='message', + message_id='msg-1', + ) + ), + } + + events: list[StreamEvent] = [] + async for event in task_manager.stream_message(request): + events.append(event) + # Stop when we get the final status + if isinstance(event, dict) and event.get('kind') == 'status-update' and event.get('final'): + break + + # Should have at least: task, working status, message, final status + assert len(events) >= 4 + + # First event should be the task (submitted state) + assert events[0]['kind'] == 'task' + + # Should have a working status update + working_events = [ + e + for e in events + if e.get('kind') == 'status-update' and e.get('status', {}).get('state') == 'working' + ] + assert len(working_events) >= 1 + + # Should have an agent message + message_events = [e for e in events if e.get('kind') == 'message' and e.get('role') == 'agent'] + assert len(message_events) >= 1 - events: list[StreamEvent] = [] - async for event in task_manager.stream_message(request): - events.append(event) - # Stop when we get the final status - if isinstance(event, dict) and event.get('kind') == 'status-update' and event.get('final'): - break - - # Should have at least: task, working status, message, final status - assert len(events) >= 4 - - # First event should be the task (submitted state) - assert events[0]['kind'] == 'task' - - # Should have a working status update - working_events = [ - e - for e in events - if e.get('kind') == 'status-update' and e.get('status', {}).get('state') == 'working' - ] - assert len(working_events) >= 1 - - # Should have an agent message - message_events = [e for e in events if e.get('kind') == 'message' and e.get('role') == 'agent'] - assert len(message_events) >= 1 - - # Last event should be final status - assert events[-1]['kind'] == 'status-update' - assert events[-1]['final'] is True - assert events[-1]['status']['state'] == 'completed' + # Last event should be final status + assert events[-1]['kind'] == 'status-update' + assert events[-1]['final'] is True + assert events[-1]['status']['state'] == 'completed' From e5e2be8a2f6ceb1c8743374d069499f36b64e8b3 Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Sat, 7 Mar 2026 18:45:39 +0100 Subject: [PATCH 11/12] fix: ci --- fasta2a/applications.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/fasta2a/applications.py b/fasta2a/applications.py index 0dabb51..18c7256 100644 --- a/fasta2a/applications.py +++ b/fasta2a/applications.py @@ -148,9 +148,7 @@ async def sse_generator(): id=request_id, result=event, ) - yield stream_message_response_ta.dump_json( - jsonrpc_response, by_alias=True - ).decode() + yield stream_message_response_ta.dump_json(jsonrpc_response, by_alias=True).decode() return EventSourceResponse(sse_generator()) elif a2a_request['method'] == 'tasks/get': From 447153ff7761b097d628806925fd0719e8f5d1a1 Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Sat, 7 Mar 2026 18:47:15 +0100 Subject: [PATCH 12/12] fix: ci --- tests/test_applications.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_applications.py b/tests/test_applications.py index 21692e0..1bc2798 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -38,7 +38,7 @@ async def test_agent_card(): 'defaultInputModes': ['application/json'], 'defaultOutputModes': ['application/json'], 'capabilities': { - 'streaming': False, + 'streaming': True, 'pushNotifications': False, 'stateTransitionHistory': False, },