From f902343ab96da7e517498d161600831d7a089897 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Radek=20Je=C5=BEek?= Date: Wed, 25 Mar 2026 16:23:24 +0100 Subject: [PATCH 1/4] feat(sdk): add streaming extensions, response accumulator, and UI streaming support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a delta-based streaming protocol using JSON Patch (RFC 6902) with a custom `str_ins` operation for efficient token-by-token text delivery. Server (adk-py): - MessageAccumulator: 3-level state machine collecting string chunks, parts, and metadata into complete messages with incremental patches - Extended JSON Patch (jsonpatch_ext): str_ins operation for optimized text streaming without full-value replacement - StreamingExtensionServer/Client: extension negotiation, patch extraction, delta emission, and graceful fallback for non-streaming clients - Refactored extension base: activation lifecycle, default demand constant consolidation (DEFAULT_DEMAND_NAME), and metadata handling - RunContext: extracted shared store logic into _prepare_store_data, added store_sync for synchronous callers - New yield types: ArtifactChunk for chunked artifact streaming Client (adk-ts / adk-ui): - TypeScript streaming extension types and Zod schemas - UI patch application (streaming.ts): JSON pointer resolution, replace/add/str_ins ops with optimized cloning (skip structuredClone for primitives) - Client streaming integration: draft accumulation, patch extraction from status update metadata, replace-mode emission - Fix: text duplication on final message — detect active streaming session and emit with replace flag to prevent appending duplicate content - Batched React state updates in AgentRunProvider to reduce re-renders during streaming - Removed dead commented-out proxy header code from adk-client Tests: - Unit tests for MessageAccumulator state machine - Unit tests for extended JSON Patch operations - Unit tests for RunContext store methods - E2E streaming tests: patch application, message reconstruction, client fallback behavior - Updated yield tests for new artifact chunking Assisted-By: Claude (Anthropic AI) Signed-off-by: Radek Ježek --- agents/chat/uv.lock | 25 + apps/adk-cli/src/kagenti_cli/api.py | 17 +- .../adk-cli/src/kagenti_cli/commands/agent.py | 273 +++++----- apps/adk-cli/uv.lock | 25 + apps/adk-py/pyproject.toml | 2 + .../a2a/extensions/auth/oauth/oauth.py | 24 +- .../a2a/extensions/auth/secrets/secrets.py | 17 +- .../src/kagenti_adk/a2a/extensions/base.py | 62 ++- .../a2a/extensions/interactions/approval.py | 8 +- .../a2a/extensions/services/embedding.py | 28 +- .../a2a/extensions/services/form.py | 2 +- .../a2a/extensions/services/llm.py | 26 +- .../a2a/extensions/services/mcp.py | 23 +- .../a2a/extensions/services/platform.py | 6 +- .../kagenti_adk/a2a/extensions/streaming.py | 335 ++++++++++++ .../kagenti_adk/a2a/extensions/tools/call.py | 8 +- .../kagenti_adk/a2a/extensions/ui/__init__.py | 117 ++++- .../a2a/extensions/ui/agent_detail.py | 2 +- .../kagenti_adk/a2a/extensions/ui/canvas.py | 2 +- .../kagenti_adk/a2a/extensions/ui/citation.py | 6 +- .../kagenti_adk/a2a/extensions/ui/error.py | 2 +- .../a2a/extensions/ui/form_request.py | 2 +- .../kagenti_adk/a2a/extensions/ui/settings.py | 2 +- .../a2a/extensions/ui/trajectory.py | 8 +- .../src/kagenti_adk/server/accumulator.py | 219 ++++++++ apps/adk-py/src/kagenti_adk/server/agent.py | 277 +++++----- .../src/kagenti_adk/server/constants.py | 2 + apps/adk-py/src/kagenti_adk/server/context.py | 37 +- .../src/kagenti_adk/server/dependencies.py | 38 +- .../src/kagenti_adk/server/exceptions.py | 7 +- .../src/kagenti_adk/server/jsonpatch_ext.py | 160 ++++++ apps/adk-py/src/kagenti_adk/server/utils.py | 39 +- apps/adk-py/src/kagenti_adk/types.py | 19 +- apps/adk-py/tests/conftest.py | 9 - apps/adk-py/tests/e2e/conftest.py | 56 ++- apps/adk-py/tests/e2e/test_streaming.py | 476 ++++++++++++++++++ apps/adk-py/tests/e2e/test_yields.py | 135 +++-- apps/adk-py/tests/test_merge_utils.py | 60 +++ .../tests/unit/server/test_accumulator.py | 292 +++++++++++ apps/adk-py/tests/unit/server/test_context.py | 188 +++++++ .../tests/unit/server/test_jsonpatch_ext.py | 78 +++ .../unit/test_agent_detail_population.py | 5 + apps/adk-py/tests/unit/test_dependencies.py | 19 +- apps/adk-py/uv.lock | 25 + apps/adk-server/pyproject.toml | 1 + .../src/adk_server/api/routes/a2a.py | 2 +- .../e2e/agents/test_platform_extensions.py | 17 +- apps/adk-server/uv.lock | 27 + .../adk-ts/src/client/a2a/extensions/index.ts | 1 + .../adk-ts/src/client/a2a/extensions/types.ts | 1 + .../a2a/extensions/ui/citation/index.ts | 8 +- .../a2a/extensions/ui/streaming/index.ts | 20 + .../a2a/extensions/ui/streaming/schemas.ts | 18 + .../a2a/extensions/ui/streaming/types.ts | 12 + .../a2a/extensions/ui/trajectory/index.ts | 4 +- apps/adk-ui/src/api/a2a/agent-card.ts | 3 +- apps/adk-ui/src/api/a2a/client.ts | 33 +- apps/adk-ui/src/api/a2a/jsonrpc-client.ts | 12 +- apps/adk-ui/src/api/a2a/part-processors.ts | 8 +- apps/adk-ui/src/api/a2a/streaming.ts | 136 +++++ apps/adk-ui/src/api/a2a/types.ts | 6 +- apps/adk-ui/src/api/a2a/utils.ts | 2 + apps/adk-ui/src/api/adk-client.ts | 15 - apps/adk-ui/src/app/(auth)/auth.ts | 2 +- .../src/modules/platform-context/constants.ts | 2 +- .../contexts/agent-run/AgentRunProvider.tsx | 17 +- .../a2a/extensions/ui/streaming/index.ts | 20 + .../a2a/extensions/ui/streaming/schemas.ts | 18 + .../a2a/extensions/ui/streaming/types.ts | 12 + 69 files changed, 3058 insertions(+), 502 deletions(-) create mode 100644 apps/adk-py/src/kagenti_adk/a2a/extensions/streaming.py create mode 100644 apps/adk-py/src/kagenti_adk/server/accumulator.py create mode 100644 apps/adk-py/src/kagenti_adk/server/jsonpatch_ext.py delete mode 100644 apps/adk-py/tests/conftest.py create mode 100644 apps/adk-py/tests/e2e/test_streaming.py create mode 100644 apps/adk-py/tests/test_merge_utils.py create mode 100644 apps/adk-py/tests/unit/server/test_accumulator.py create mode 100644 apps/adk-py/tests/unit/server/test_context.py create mode 100644 apps/adk-py/tests/unit/server/test_jsonpatch_ext.py create mode 100644 apps/adk-ts/src/client/a2a/extensions/ui/streaming/index.ts create mode 100644 apps/adk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts create mode 100644 apps/adk-ts/src/client/a2a/extensions/ui/streaming/types.ts create mode 100644 apps/adk-ui/src/api/a2a/streaming.ts create mode 100644 apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/index.ts create mode 100644 apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts create mode 100644 apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/types.ts diff --git a/agents/chat/uv.lock b/agents/chat/uv.lock index fea0a6da..82d4a46a 100644 --- a/agents/chat/uv.lock +++ b/agents/chat/uv.lock @@ -889,6 +889,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/94/9e/820c4b086ad01ba7d77369fb8b11470a01fac9b4977f02e18659cf378b6b/json_rpc-1.15.0-py2.py3-none-any.whl", hash = "sha256:4a4668bbbe7116feb4abbd0f54e64a4adcf4b8f648f19ffa0848ad0f6606a9bf", size = 39450, upload-time = "2023-06-11T09:45:47.136Z" }, ] +[[package]] +name = "jsonpatch" +version = "1.33" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jsonpointer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/78/18813351fe5d63acad16aec57f94ec2b70a09e53ca98145589e185423873/jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c", size = 21699, upload-time = "2023-06-26T12:07:29.144Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/07/02e16ed01e04a374e644b575638ec7987ae846d25ad97bcc9945a3ee4b0e/jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade", size = 12898, upload-time = "2023-06-16T21:01:28.466Z" }, +] + +[[package]] +name = "jsonpointer" +version = "3.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/18/c7/af399a2e7a67fd18d63c40c5e62d3af4e67b836a2107468b6a5ea24c4304/jsonpointer-3.1.1.tar.gz", hash = "sha256:0b801c7db33a904024f6004d526dcc53bbb8a4a0f4e32bfd10beadf60adf1900", size = 9068, upload-time = "2026-03-23T22:32:32.458Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/6a/a83720e953b1682d2d109d3c2dbb0bc9bf28cc1cbc205be4ef4be5da709d/jsonpointer-3.1.1-py3-none-any.whl", hash = "sha256:8ff8b95779d071ba472cf5bc913028df06031797532f08a7d5b602d8b2a488ca", size = 7659, upload-time = "2026-03-23T22:32:31.568Z" }, +] + [[package]] name = "jsonref" version = "1.1.0" @@ -932,6 +953,7 @@ source = { editable = "../../apps/adk-py" } dependencies = [ { name = "a2a-sdk", extra = ["sqlite"] }, { name = "anyio" }, + { name = "asgiref" }, { name = "async-lru" }, { name = "asyncclick" }, { name = "authlib" }, @@ -939,6 +961,7 @@ dependencies = [ { name = "fastapi" }, { name = "httpx" }, { name = "janus" }, + { name = "jsonpatch" }, { name = "mcp" }, { name = "objprint" }, { name = "opentelemetry-api" }, @@ -958,6 +981,7 @@ dependencies = [ requires-dist = [ { name = "a2a-sdk", extras = ["sqlite"], specifier = "==1.0.0a0" }, { name = "anyio", specifier = ">=4.9.0" }, + { name = "asgiref", specifier = ">=3.11.0" }, { name = "async-lru", specifier = ">=2.0.4" }, { name = "asyncclick", specifier = ">=8.1.8" }, { name = "authlib", specifier = ">=1.3.0" }, @@ -965,6 +989,7 @@ requires-dist = [ { name = "fastapi", specifier = ">=0.116.1" }, { name = "httpx" }, { name = "janus", specifier = ">=2.0.0" }, + { name = "jsonpatch", specifier = ">=1.33" }, { name = "mcp", specifier = ">=1.12.3" }, { name = "objprint", specifier = ">=0.3.0" }, { name = "opentelemetry-api", specifier = ">=1.35.0" }, diff --git a/apps/adk-cli/src/kagenti_cli/api.py b/apps/adk-cli/src/kagenti_cli/api.py index ba62c71c..bf16d964 100644 --- a/apps/adk-cli/src/kagenti_cli/api.py +++ b/apps/adk-cli/src/kagenti_cli/api.py @@ -16,7 +16,8 @@ import httpx import openai import pydantic -from a2a.client import A2AClientError, Client, ClientConfig, ClientFactory +from a2a.client import A2AClientError, Client, ClientCallContext, ClientConfig, ClientFactory +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.types import AgentCard from kagenti_adk.platform.context import ContextToken from google.protobuf.json_format import MessageToDict @@ -131,8 +132,18 @@ class OpenAPISchema(pydantic.BaseModel): return None +def make_extension_context(extensions: list[str] | None = None) -> ClientCallContext | None: + """Create a ClientCallContext with extension URIs as service parameters.""" + if not extensions: + return None + return ClientCallContext(service_parameters={HTTP_EXTENSION_HEADER: ",".join(extensions)}) + + @asynccontextmanager -async def a2a_client(agent_card: AgentCard, context_token: ContextToken) -> AsyncIterator[Client]: +async def a2a_client( + agent_card: AgentCard, + context_token: ContextToken, +) -> AsyncIterator[Client]: try: async with httpx.AsyncClient( headers={"Authorization": f"Bearer {context_token.token.get_secret_value()}"}, @@ -140,7 +151,7 @@ async def a2a_client(agent_card: AgentCard, context_token: ContextToken) -> Asyn timeout=timedelta(hours=1).total_seconds(), ) as httpx_client: yield ClientFactory(ClientConfig(httpx_client=httpx_client, use_client_preference=True)).create( - card=agent_card + card=agent_card, ) except A2AClientError as ex: card_data = json.dumps( diff --git a/apps/adk-cli/src/kagenti_cli/commands/agent.py b/apps/adk-cli/src/kagenti_cli/commands/agent.py index 635a6a19..6e88a5e7 100644 --- a/apps/adk-cli/src/kagenti_cli/commands/agent.py +++ b/apps/adk-cli/src/kagenti_cli/commands/agent.py @@ -22,6 +22,7 @@ Message, Part, Role, + SendMessageRequest, TaskState, ) from kagenti_adk.a2a.extensions import ( @@ -35,9 +36,18 @@ LLMServiceExtensionSpec, PlatformApiExtensionClient, PlatformApiExtensionSpec, - TrajectoryExtensionClient, + Trajectory, TrajectoryExtensionSpec, ) +from kagenti_adk.a2a.extensions.streaming import ( + ArtifactDelta, + MetadataDelta, + PartDelta, + StateChange, + StreamingExtensionClient, + StreamingExtensionSpec, + TextDelta, +) from kagenti_adk.a2a.extensions.common.form import ( CheckboxField, CheckboxFieldValue, @@ -115,7 +125,7 @@ from rich.markdown import Markdown from rich.table import Column -from kagenti_cli.api import a2a_client +from kagenti_cli.api import a2a_client, make_extension_context from kagenti_cli.async_typer import AsyncTyper, console, create_table, err_console from kagenti_cli.server_utils import announce_server_action, confirm_server_action from kagenti_cli.utils import ( @@ -199,14 +209,18 @@ async def _discover_agent_card(location: str) -> AgentCard: @app.command("add") async def add_agent( - location: typing.Annotated[ - str | None, typer.Argument(help="Agent image or network URL") + location: typing.Annotated[str | None, typer.Argument(help="Agent image or network URL")] = None, + name: typing.Annotated[ + str | None, typer.Option("--name", "-n", help="Agent name (default: derived from image)") ] = None, - name: typing.Annotated[str | None, typer.Option("--name", "-n", help="Agent name (default: derived from image)")] = None, namespace: typing.Annotated[str, typer.Option(help="Target Kubernetes namespace")] = "team1", port: typing.Annotated[int, typer.Option(help="Agent service port")] = 8080, - env: typing.Annotated[list[str] | None, typer.Option("--env", "-e", help="Environment variable in KEY=VALUE format (repeatable)")] = None, - env_file: typing.Annotated[str | None, typer.Option("--env-file", help="Path to env file (KEY=VALUE per line)")] = None, + env: typing.Annotated[ + list[str] | None, typer.Option("--env", "-e", help="Environment variable in KEY=VALUE format (repeatable)") + ] = None, + env_file: typing.Annotated[ + str | None, typer.Option("--env-file", help="Path to env file (KEY=VALUE per line)") + ] = None, yes: typing.Annotated[bool, typer.Option("--yes", "-y", help="Skip confirmation prompts.")] = False, ) -> None: """Add an agent by container image or network URL. [Admin only]""" @@ -394,9 +408,7 @@ async def update_agent( search_path: typing.Annotated[ str | None, typer.Argument(help="Short ID, agent name or part of the provider location of agent to replace") ] = None, - location: typing.Annotated[ - str | None, typer.Argument(help="New agent location (network URL)") - ] = None, + location: typing.Annotated[str | None, typer.Argument(help="New agent location (network URL)")] = None, yes: typing.Annotated[bool, typer.Option("--yes", "-y", help="Skip confirmation prompts.")] = False, ) -> None: """Update an agent's location. [Admin only]""" @@ -760,9 +772,9 @@ async def _run_agent( console_status_stopped = False log_type = None + pending_form_metadata = None - trajectory_spec = TrajectoryExtensionSpec.from_agent_card(agent_card) - trajectory_extension = TrajectoryExtensionClient(trajectory_spec) if trajectory_spec else None + has_trajectory = TrajectoryExtensionSpec.from_agent_card(agent_card) is not None llm_spec = LLMServiceExtensionSpec.from_agent_card(agent_card) embedding_spec = EmbeddingServiceExtensionSpec.from_agent_card(agent_card) platform_extension_spec = PlatformApiExtensionSpec.from_agent_card(agent_card) @@ -860,30 +872,39 @@ async def _run_agent( metadata=metadata, ) - stream = client.send_message(msg) + streaming_spec = StreamingExtensionSpec.from_agent_card(agent_card) + streaming = StreamingExtensionClient(streaming_spec or StreamingExtensionSpec()) + extension_context = make_extension_context([ext.uri for ext in agent_card.capabilities.extensions or []]) + stream = client.send_message(SendMessageRequest(message=msg), context=extension_context) while True: - async for response, task in stream: + async for delta, task_obj in streaming.stream(stream): if not console_status_stopped: console_status_stopped = True console_status.stop() - task_id = task.id if task else task_id - - if response.HasField("status_update"): - update = response.status_update - status = update.status - state = status.state - message = status.message if status.HasField("message") else None - - if state == TaskState.TASK_STATE_COMPLETED: - console.print() # Add newline after completion - return - - elif state in (TaskState.TASK_STATE_WORKING, TaskState.TASK_STATE_SUBMITTED): - # Handle streaming content during working state - if message: - if trajectory_extension and (trajectory := trajectory_extension.parse_server_metadata(message)): + task_id = task_obj.id if task_obj else task_id + + match delta: + case TextDelta(delta=text): + if log_type: + err_console.print() + log_type = None + console.print(text, end="") + + case PartDelta(part=part): + if log_type: + err_console.print() + log_type = None + if "text" in part: + console.print(part["text"], end="") + + case MetadataDelta(metadata=meta): + if FormRequestExtensionSpec.URI in meta: + pending_form_metadata = meta[FormRequestExtensionSpec.URI] + if has_trajectory and TrajectoryExtensionSpec.URI in meta: + for entry in meta[TrajectoryExtensionSpec.URI]: + trajectory = Trajectory.model_validate(entry) if update_kind := trajectory.title: if update_kind != log_type: if log_type is not None: @@ -891,109 +912,106 @@ async def _run_agent( err_console.print(f"{update_kind}: ", style="dim", end="") log_type = update_kind err_console.print(trajectory.content or "", style="dim", end="") - else: - # This is regular message content - if log_type: - console.print() - log_type = None - for part in message.parts: - if part.HasField("text"): - console.print(part.text, end="") - - elif state == TaskState.TASK_STATE_INPUT_REQUIRED: - if handle_input is None: - raise ValueError("Agent requires input but no input handler provided") - - if form_metadata := ( - MessageToDict(message.metadata).get(FormRequestExtensionSpec.URI) - if message and message.metadata - else None - ): - stream = client.send_message( - Message( - message_id=str(uuid4()), - parts=[], - role=Role.ROLE_USER, - task_id=task_id, - context_id=context_token.context_id, - metadata={ - FormRequestExtensionSpec.URI: ( - await _ask_form_questions(FormRender.model_validate(form_metadata)) - ).model_dump(mode="json") - }, + + case ArtifactDelta(event=artifact_event): + artifact = artifact_event.artifact + if dump_files_path is None: + continue + dump_files_path.mkdir(parents=True, exist_ok=True) + full_path = dump_files_path / (artifact.name or "unnamed").lstrip("/") + full_path.resolve().relative_to(dump_files_path.resolve()) + full_path.parent.mkdir(parents=True, exist_ok=True) + try: + for part in artifact.parts[:1]: + if part.HasField("raw"): + full_path.write_bytes(part.raw) + console.print(f"📁 Saved {full_path}") + elif part.HasField("url"): + uri = part.url + if uri.startswith("agentstack://"): + async with File.load_content(uri.removeprefix("agentstack://")) as file: + full_path.write_bytes(file.content) + else: + async with httpx.AsyncClient() as httpx_client: + full_path.write_bytes((await httpx_client.get(uri)).content) + console.print(f"📁 Saved {full_path}") + elif part.HasField("text"): + full_path.write_text(part.text) + else: + console.print(f"⚠️ Artifact part {type(part).__name__} is not supported") + if len(artifact.parts) > 1: + console.print("⚠️ Artifact with more than 1 part are not supported.") + except ValueError: + console.print(f"⚠️ Skipping artifact {artifact.name} - outside dump directory") + + case StateChange(state=state): + if log_type: + err_console.print() + log_type = None + if state == TaskState.TASK_STATE_COMPLETED: + console.print() # Add newline after completion + return + + elif state in (TaskState.TASK_STATE_WORKING, TaskState.TASK_STATE_SUBMITTED): + pass + + elif state == TaskState.TASK_STATE_INPUT_REQUIRED: + if handle_input is None: + raise ValueError("Agent requires input but no input handler provided") + + if pending_form_metadata: + stream = client.send_message( + SendMessageRequest( + message=Message( + message_id=str(uuid4()), + parts=[], + role=Role.ROLE_USER, + task_id=task_id, + context_id=context_token.context_id, + metadata={ + FormRequestExtensionSpec.URI: ( + await _ask_form_questions( + FormRender.model_validate(pending_form_metadata) + ) + ).model_dump(mode="json") + }, + ) + ), + context=extension_context, ) - ) - break + pending_form_metadata = None + break - text = "" - for part in message.parts if message else []: - if part.HasField("text"): - text = part.text - console.print(f"\n[bold]Agent requires your input[/bold]: {text}\n") - user_input = handle_input() - stream = client.send_message( - Message( - message_id=str(uuid4()), - parts=[Part(text=user_input)], - role=Role.ROLE_USER, - task_id=task_id, - context_id=context_token.context_id, + console.print("\n[bold]Agent requires your input[/bold]\n") + user_input = handle_input() + stream = client.send_message( + SendMessageRequest( + message=Message( + message_id=str(uuid4()), + parts=[Part(text=user_input)], + role=Role.ROLE_USER, + task_id=task_id, + context_id=context_token.context_id, + ) + ), + context=extension_context, ) - ) - break - - elif state in ( - TaskState.TASK_STATE_CANCELED, - TaskState.TASK_STATE_FAILED, - TaskState.TASK_STATE_REJECTED, - ): - error = "" - if message and message.parts and message.parts[0].HasField("text"): - error = message.parts[0].text - console.print(f"\n:boom: [red][bold]Task {TaskState.Name(state)}[/bold][/red]") - console.print(Markdown(error)) - return - - elif state == TaskState.TASK_STATE_AUTH_REQUIRED: - console.print("[yellow]Authentication required[/yellow]") - return + break - else: - console.print(f"[yellow]Unknown task status: {state}[/yellow]") + elif state in ( + TaskState.TASK_STATE_CANCELED, + TaskState.TASK_STATE_FAILED, + TaskState.TASK_STATE_REJECTED, + ): + console.print(f"\n:boom: [red][bold]Task {TaskState.Name(state)}[/bold][/red]") + return - elif response.HasField("artifact_update"): - artifact = response.artifact_update.artifact - if dump_files_path is None: - continue - dump_files_path.mkdir(parents=True, exist_ok=True) - full_path = dump_files_path / (artifact.name or "unnamed").lstrip("/") - full_path.resolve().relative_to(dump_files_path.resolve()) - full_path.parent.mkdir(parents=True, exist_ok=True) - try: - for part in artifact.parts[:1]: - if part.HasField("raw"): - full_path.write_bytes(part.raw) - console.print(f"📁 Saved {full_path}") - elif part.HasField("url"): - uri = part.url - if uri.startswith("adk://"): - async with File.load_content(uri.removeprefix("adk://")) as file: - full_path.write_bytes(file.content) - else: - async with httpx.AsyncClient() as httpx_client: - full_path.write_bytes((await httpx_client.get(uri)).content) - console.print(f"📁 Saved {full_path}") - elif part.HasField("text"): - full_path.write_text(part.text) - else: - console.print(f"⚠️ Artifact part {type(part).__name__} is not supported") - if len(artifact.parts) > 1: - console.print("⚠️ Artifact with more than 1 part are not supported.") - except ValueError: - console.print(f"⚠️ Skipping artifact {artifact.name} - outside dump directory") + elif state == TaskState.TASK_STATE_AUTH_REQUIRED: + console.print("[yellow]Authentication required[/yellow]") + return - else: - print(response) + else: + console.print(f"[yellow]Unknown task status: {state}[/yellow]") else: break # Stream ended normally @@ -1295,6 +1313,7 @@ async def run_agent( settings=settings_input, dump_files_path=dump_files, handle_input=handle_input, + ) console.print() turn_input = handle_input() diff --git a/apps/adk-cli/uv.lock b/apps/adk-cli/uv.lock index 4f0207f0..9e548f30 100644 --- a/apps/adk-cli/uv.lock +++ b/apps/adk-cli/uv.lock @@ -557,6 +557,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/94/9e/820c4b086ad01ba7d77369fb8b11470a01fac9b4977f02e18659cf378b6b/json_rpc-1.15.0-py2.py3-none-any.whl", hash = "sha256:4a4668bbbe7116feb4abbd0f54e64a4adcf4b8f648f19ffa0848ad0f6606a9bf", size = 39450, upload-time = "2023-06-11T09:45:47.136Z" }, ] +[[package]] +name = "jsonpatch" +version = "1.33" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jsonpointer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/78/18813351fe5d63acad16aec57f94ec2b70a09e53ca98145589e185423873/jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c", size = 21699, upload-time = "2023-06-26T12:07:29.144Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/07/02e16ed01e04a374e644b575638ec7987ae846d25ad97bcc9945a3ee4b0e/jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade", size = 12898, upload-time = "2023-06-16T21:01:28.466Z" }, +] + +[[package]] +name = "jsonpointer" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/0a/eebeb1fa92507ea94016a2a790b93c2ae41a7e18778f85471dc54475ed25/jsonpointer-3.0.0.tar.gz", hash = "sha256:2b2d729f2091522d61c3b31f82e11870f60b68f43fbc705cb76bf4b832af59ef", size = 9114, upload-time = "2024-06-10T19:24:42.462Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/71/92/5e77f98553e9e75130c78900d000368476aed74276eb8ae8796f65f00918/jsonpointer-3.0.0-py2.py3-none-any.whl", hash = "sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942", size = 7595, upload-time = "2024-06-10T19:24:40.698Z" }, +] + [[package]] name = "jsonschema" version = "4.26.0" @@ -591,6 +612,7 @@ source = { editable = "../adk-py" } dependencies = [ { name = "a2a-sdk", extra = ["sqlite"] }, { name = "anyio" }, + { name = "asgiref" }, { name = "async-lru" }, { name = "asyncclick" }, { name = "authlib" }, @@ -598,6 +620,7 @@ dependencies = [ { name = "fastapi" }, { name = "httpx" }, { name = "janus" }, + { name = "jsonpatch" }, { name = "mcp" }, { name = "objprint" }, { name = "opentelemetry-api" }, @@ -617,6 +640,7 @@ dependencies = [ requires-dist = [ { name = "a2a-sdk", extras = ["sqlite"], specifier = "==1.0.0a0" }, { name = "anyio", specifier = ">=4.9.0" }, + { name = "asgiref", specifier = ">=3.11.0" }, { name = "async-lru", specifier = ">=2.0.4" }, { name = "asyncclick", specifier = ">=8.1.8" }, { name = "authlib", specifier = ">=1.3.0" }, @@ -624,6 +648,7 @@ requires-dist = [ { name = "fastapi", specifier = ">=0.116.1" }, { name = "httpx" }, { name = "janus", specifier = ">=2.0.0" }, + { name = "jsonpatch", specifier = ">=1.33" }, { name = "mcp", specifier = ">=1.12.3" }, { name = "objprint", specifier = ">=0.3.0" }, { name = "opentelemetry-api", specifier = ">=1.35.0" }, diff --git a/apps/adk-py/pyproject.toml b/apps/adk-py/pyproject.toml index 0424cce9..4e5d2e04 100644 --- a/apps/adk-py/pyproject.toml +++ b/apps/adk-py/pyproject.toml @@ -28,6 +28,8 @@ dependencies = [ "typing-extensions>=4.15.0", "opentelemetry-instrumentation-httpx>=0.60b1", "opentelemetry-instrumentation-openai>=0.52.3", + "asgiref>=3.11.0", + "jsonpatch>=1.33", ] [dependency-groups] diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/auth/oauth/oauth.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/auth/oauth/oauth.py index 03613f05..76b147ac 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/auth/oauth/oauth.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/auth/oauth/oauth.py @@ -19,7 +19,7 @@ from typing_extensions import override from kagenti_adk.a2a.extensions.auth.oauth.storage import MemoryTokenStorageFactory, TokenStorageFactory -from kagenti_adk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec +from kagenti_adk.a2a.extensions.base import DEFAULT_DEMAND_NAME, BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec from kagenti_adk.a2a.types import AgentMessage, AuthRequired, Metadata, RunYieldResume from kagenti_adk.util.pydantic import REVEAL_SECRETS, SecureBaseModel @@ -50,7 +50,6 @@ "OAuthFulfillment", ] -_DEFAULT_DEMAND_NAME = "default" class AuthRequest(SecureBaseModel): @@ -74,19 +73,22 @@ class OAuthExtensionParams(pydantic.BaseModel): """Server requests that the agent requires to be provided by the client.""" -class OAuthExtensionSpec(BaseExtensionSpec[OAuthExtensionParams]): - URI: str = "https://a2a-extensions.adk.kagenti.dev/auth/oauth/v1" - - @classmethod - def single_demand(cls, name: str = _DEFAULT_DEMAND_NAME) -> Self: - return cls(params=OAuthExtensionParams(oauth_demands={name: OAuthDemand()})) - - class OAuthExtensionMetadata(pydantic.BaseModel): oauth_fulfillments: dict[str, OAuthFulfillment] = {} """Provided servers corresponding to the server requests.""" +class OAuthExtensionSpec(BaseExtensionSpec[OAuthExtensionParams, OAuthExtensionMetadata]): + URI: str = "https://a2a-extensions.adk.kagenti.dev/auth/oauth/v1" + + @classmethod + def single_demand(cls, name: str = DEFAULT_DEMAND_NAME, default: OAuthFulfillment | None = None) -> Self: + return cls( + params=OAuthExtensionParams(oauth_demands={name: OAuthDemand()}), + default=OAuthExtensionMetadata(oauth_fulfillments={name: default}) if default else None, + ) + + class OAuthExtensionServer(BaseExtensionServer[OAuthExtensionSpec, OAuthExtensionMetadata]): context: RunContext token_storage_factory: TokenStorageFactory @@ -105,7 +107,7 @@ def _get_fulfillment_for_resource(self, resource_url: pydantic.AnyUrl): raise RuntimeError("No fulfillments found") fulfillment = self.data.oauth_fulfillments.get(str(resource_url)) or self.data.oauth_fulfillments.get( - _DEFAULT_DEMAND_NAME + DEFAULT_DEMAND_NAME ) if fulfillment: return fulfillment diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/auth/secrets/secrets.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/auth/secrets/secrets.py index 0193c295..f03660c0 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/auth/secrets/secrets.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/auth/secrets/secrets.py @@ -13,7 +13,7 @@ from opentelemetry import trace from typing_extensions import override -from kagenti_adk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec +from kagenti_adk.a2a.extensions.base import DEFAULT_DEMAND_NAME, BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec from kagenti_adk.a2a.types import AgentMessage, AuthRequired from kagenti_adk.util.pydantic import REDACT_SECRETS, REVEAL_SECRETS, SecureBaseModel from kagenti_adk.util.telemetry import flatten_dict @@ -64,15 +64,22 @@ class SecretsServiceExtensionMetadata(pydantic.BaseModel): secret_fulfillments: dict[str, SecretFulfillment] = {} -class SecretsExtensionSpec(BaseExtensionSpec[SecretsServiceExtensionParams | None]): +class SecretsExtensionSpec(BaseExtensionSpec[SecretsServiceExtensionParams | None, SecretsServiceExtensionMetadata]): URI: str = "https://a2a-extensions.adk.kagenti.dev/auth/secrets/v1" @classmethod - def single_demand(cls, name: str, key: str | None = None, description: str | None = None) -> Self: + def single_demand( + cls, + name: str, + key: str = DEFAULT_DEMAND_NAME, + description: str | None = None, + default: SecretFulfillment | None = None, + ) -> Self: return cls( params=SecretsServiceExtensionParams( - secret_demands={key or "default": SecretDemand(description=description, name=name)} - ) + secret_demands={key: SecretDemand(description=description, name=name)} + ), + default=SecretsServiceExtensionMetadata(secret_fulfillments={key: default}) if default else None, ) diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/base.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/base.py index e433722e..3135357f 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/base.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/base.py @@ -5,6 +5,7 @@ from __future__ import annotations import abc +import logging import typing from collections.abc import AsyncIterator from contextlib import asynccontextmanager @@ -29,8 +30,10 @@ ) ParamsT = typing.TypeVar("ParamsT") -MetadataFromClientT = typing.TypeVar("MetadataFromClientT") -MetadataFromServerT = typing.TypeVar("MetadataFromServerT") +MetadataFromClientT = typing.TypeVar("MetadataFromClientT", bound=BaseModel | NoneType) +MetadataFromServerT = typing.TypeVar("MetadataFromServerT", bound=BaseModel | list | NoneType) + +logger = logging.getLogger(__name__) if typing.TYPE_CHECKING: @@ -39,16 +42,18 @@ A2A_EXTENSION_URI = "a2a_extension.uri" A2A_EXTENSION_METADATA_RECEIVED_EVENT = "a2a_extension.metadata.received" +DEFAULT_DEMAND_NAME = "default" def _get_generic_args(cls: type, base_class: type) -> tuple[typing.Any, ...]: - for base in getattr(cls, "__orig_bases__", ()): - if typing.get_origin(base) is base_class and (args := typing.get_args(base)): - return args + for klass in cls.__mro__: + for base in getattr(klass, "__orig_bases__", ()): + if typing.get_origin(base) is base_class and (args := typing.get_args(base)): + return args raise TypeError(f"Missing Params type for {cls.__name__}") -class BaseExtensionSpec(abc.ABC, typing.Generic[ParamsT]): +class BaseExtensionSpec(abc.ABC, typing.Generic[ParamsT, MetadataFromClientT]): """ Base class for an A2A extension handler. @@ -76,12 +81,18 @@ def __init_subclass__(cls, **kwargs): Params from the agent card. """ - def __init__(self, params: ParamsT, required: bool = False) -> None: + default: MetadataFromClientT | None = None + """ + Default metadata to use if the client does not provide any. + """ + + def __init__(self, params: ParamsT, required: bool = False, default: MetadataFromClientT | None = None) -> None: """ Agent should construct an extension instance using the constructor. """ self.params = params self.required = required + self.default = default @classmethod def from_agent_card(cls: type[typing.Self], agent: AgentCard) -> typing.Self | None: @@ -111,9 +122,9 @@ def to_agent_card_extensions(self, *, required: bool | None = None) -> list[Agen ] -class NoParamsBaseExtensionSpec(BaseExtensionSpec[NoneType]): - def __init__(self, required: bool = False): - super().__init__(None, required) +class NoParamsBaseExtensionSpec(typing.Generic[MetadataFromClientT], BaseExtensionSpec[NoneType, MetadataFromClientT]): + def __init__(self, required: bool = False, default: MetadataFromClientT | None = None): + super().__init__(None, required, default) @classmethod @override @@ -123,7 +134,7 @@ def from_agent_card(cls, agent: AgentCard) -> typing.Self | None: return None -ExtensionSpecT = typing.TypeVar("ExtensionSpecT", bound=BaseExtensionSpec[typing.Any]) +ExtensionSpecT = typing.TypeVar("ExtensionSpecT", bound=BaseExtensionSpec[typing.Any, typing.Any]) class BaseExtensionServer(abc.ABC, typing.Generic[ExtensionSpecT, MetadataFromClientT]): @@ -131,6 +142,8 @@ class BaseExtensionServer(abc.ABC, typing.Generic[ExtensionSpecT, MetadataFromCl Type of the extension metadata, attached to messages. """ + _is_active: bool = False + def __init_subclass__(cls: type[Self], **kwargs): super().__init_subclass__(**kwargs) @@ -150,11 +163,28 @@ def current(cls) -> Self | None: return cls._context_var.get() @property - def data(self) -> MetadataFromClientT | None: - return self._metadata_from_client + def data(self) -> MetadataFromClientT: + if self.MetadataFromClient is NoneType: + return None # type: ignore + + if self._metadata_from_client: + if not self._is_active: + logger.warning("Extension metadata received but extension is not active.") + return self._metadata_from_client + + if self.spec.default is not None: + return self.spec.default + + if not self._is_active: + raise AttributeError(f"Cannot access 'data' attribute: extension '{self.spec.URI}' is not active.") + + raise AttributeError( + f"Extension '{self.spec.URI}' is active but no metadata provided and no default available." + ) def __bool__(self): - return bool(self.data) + # fallback - if we receive metadata but not an extension activation header + return bool(self._is_active or self._metadata_from_client) def __init__(self, spec: ExtensionSpecT, *args, **kwargs) -> None: self.spec = spec @@ -181,6 +211,10 @@ def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, attributes=flatten_dict(self._metadata_from_client.model_dump(context={REDACT_SECRETS: True})), ) + if not self._is_active and request_context.call_context: + self._is_active = self.spec.URI in request_context.call_context.requested_extensions + request_context.call_context.activated_extensions.add(self.spec.URI) + def _fork(self) -> typing.Self: """Creates a clone of this instance with the same arguments as the original""" return type(self)(self.spec, *self._args, **self._kwargs) diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/interactions/approval.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/interactions/approval.py index a7e60cda..daf6b659 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/interactions/approval.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/interactions/approval.py @@ -115,14 +115,14 @@ class ApprovalExtensionParams(BaseModel): pass -class ApprovalExtensionSpec(BaseExtensionSpec[ApprovalExtensionParams]): - URI: str = "https://a2a-extensions.adk.kagenti.dev/interactions/approval/v1" - - class ApprovalExtensionMetadata(BaseModel): pass +class ApprovalExtensionSpec(BaseExtensionSpec[ApprovalExtensionParams, ApprovalExtensionMetadata]): + URI: str = "https://a2a-extensions.adk.kagenti.dev/interactions/approval/v1" + + class ApprovalExtensionServer(BaseExtensionServer[ApprovalExtensionSpec, ApprovalExtensionMetadata]): def create_request_message(self, *, request: ApprovalRequest): return AgentMessage( diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/embedding.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/embedding.py index fb89457f..f867ff16 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/embedding.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/embedding.py @@ -13,7 +13,7 @@ from a2a.types import Message as A2AMessage from typing_extensions import override -from kagenti_adk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec +from kagenti_adk.a2a.extensions.base import DEFAULT_DEMAND_NAME, BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec from kagenti_adk.util.pydantic import REVEAL_SECRETS, SecureBaseModel, redact_str __all__ = [ @@ -40,6 +40,7 @@ ] + class EmbeddingFulfillment(SecureBaseModel): identifier: str | None = None """ @@ -88,25 +89,32 @@ class EmbeddingServiceExtensionParams(pydantic.BaseModel): """Model requests that the agent requires to be provided by the client.""" -class EmbeddingServiceExtensionSpec(BaseExtensionSpec[EmbeddingServiceExtensionParams]): +class EmbeddingServiceExtensionMetadata(pydantic.BaseModel): + embedding_fulfillments: dict[str, EmbeddingFulfillment] = {} + """Provided models corresponding to the model requests.""" + + +class EmbeddingServiceExtensionSpec( + BaseExtensionSpec[EmbeddingServiceExtensionParams, EmbeddingServiceExtensionMetadata] +): URI: str = "https://a2a-extensions.adk.kagenti.dev/services/embedding/v1" @classmethod def single_demand( - cls, name: str | None = None, description: str | None = None, suggested: tuple[str, ...] = () + cls, + name: str = DEFAULT_DEMAND_NAME, + description: str | None = None, + suggested: tuple[str, ...] = (), + default: EmbeddingFulfillment | None = None, ) -> Self: return cls( params=EmbeddingServiceExtensionParams( - embedding_demands={name or "default": EmbeddingDemand(description=description, suggested=suggested)} - ) + embedding_demands={name: EmbeddingDemand(description=description, suggested=suggested)} + ), + default=EmbeddingServiceExtensionMetadata(embedding_fulfillments={name: default}) if default else None, ) -class EmbeddingServiceExtensionMetadata(pydantic.BaseModel): - embedding_fulfillments: dict[str, EmbeddingFulfillment] = {} - """Provided models corresponding to the model requests.""" - - class EmbeddingServiceExtensionServer( BaseExtensionServer[EmbeddingServiceExtensionSpec, EmbeddingServiceExtensionMetadata] ): diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/form.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/form.py index 90569087..93735b5b 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/form.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/form.py @@ -31,7 +31,7 @@ class FormServiceExtensionParams(BaseModel): form_demands: FormDemands -class FormServiceExtensionSpec(BaseExtensionSpec[FormServiceExtensionParams]): +class FormServiceExtensionSpec(BaseExtensionSpec[FormServiceExtensionParams, FormServiceExtensionMetadata]): URI: str = "https://a2a-extensions.adk.kagenti.dev/services/form/v1" @classmethod diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/llm.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/llm.py index 2fc16a35..c51ae18a 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/llm.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/llm.py @@ -14,7 +14,7 @@ from a2a.types import Message as A2AMessage from typing_extensions import override -from kagenti_adk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec +from kagenti_adk.a2a.extensions.base import DEFAULT_DEMAND_NAME, BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec from kagenti_adk.util.pydantic import REVEAL_SECRETS, SecureBaseModel, redact_str __all__ = [ @@ -41,6 +41,7 @@ ] + class LLMFulfillment(SecureBaseModel): identifier: str | None = None """ @@ -89,25 +90,30 @@ class LLMServiceExtensionParams(pydantic.BaseModel): """Model requests that the agent requires to be provided by the client.""" -class LLMServiceExtensionSpec(BaseExtensionSpec[LLMServiceExtensionParams]): +class LLMServiceExtensionMetadata(pydantic.BaseModel): + llm_fulfillments: dict[str, LLMFulfillment] = {} + """Provided models corresponding to the model requests.""" + + +class LLMServiceExtensionSpec(BaseExtensionSpec[LLMServiceExtensionParams, LLMServiceExtensionMetadata]): URI: str = "https://a2a-extensions.adk.kagenti.dev/services/llm/v1" @classmethod def single_demand( - cls, name: str | None = None, description: str | None = None, suggested: tuple[str, ...] = () + cls, + name: str = DEFAULT_DEMAND_NAME, + description: str | None = None, + suggested: tuple[str, ...] = (), + default: LLMFulfillment | None = None, ) -> Self: return cls( params=LLMServiceExtensionParams( - llm_demands={name or "default": LLMDemand(description=description, suggested=suggested)} - ) + llm_demands={name: LLMDemand(description=description, suggested=suggested)} + ), + default=LLMServiceExtensionMetadata(llm_fulfillments={name: default}) if default else None, ) -class LLMServiceExtensionMetadata(pydantic.BaseModel): - llm_fulfillments: dict[str, LLMFulfillment] = {} - """Provided models corresponding to the model requests.""" - - class LLMServiceExtensionServer(BaseExtensionServer[LLMServiceExtensionSpec, LLMServiceExtensionMetadata]): @override def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext): diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/mcp.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/mcp.py index 2559e658..19212611 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/mcp.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/mcp.py @@ -18,7 +18,7 @@ from typing_extensions import override from kagenti_adk.a2a.extensions.auth.oauth.oauth import OAuthExtensionServer -from kagenti_adk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec +from kagenti_adk.a2a.extensions.base import DEFAULT_DEMAND_NAME, BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec from kagenti_adk.a2a.extensions.services.platform import PlatformApiExtensionServer from kagenti_adk.platform.client import get_platform_client from kagenti_adk.util.logging import logger @@ -55,7 +55,6 @@ _TRANSPORT_TYPES = Literal["streamable_http", "stdio"] -_DEFAULT_DEMAND_NAME = "default" _DEFAULT_ALLOWED_TRANSPORTS: list[_TRANSPORT_TYPES] = ["streamable_http"] @@ -115,16 +114,22 @@ class MCPServiceExtensionParams(pydantic.BaseModel): """Server requests that the agent requires to be provided by the client.""" -class MCPServiceExtensionSpec(BaseExtensionSpec[MCPServiceExtensionParams]): +class MCPServiceExtensionMetadata(pydantic.BaseModel): + mcp_fulfillments: dict[str, MCPFulfillment] = {} + """Provided servers corresponding to the server requests.""" + + +class MCPServiceExtensionSpec(BaseExtensionSpec[MCPServiceExtensionParams, MCPServiceExtensionMetadata]): URI: str = "https://a2a-extensions.adk.kagenti.dev/services/mcp/v1" @classmethod def single_demand( cls, - name: str = _DEFAULT_DEMAND_NAME, + name: str = DEFAULT_DEMAND_NAME, description: str | None = None, suggested: tuple[str, ...] = (), allowed_transports: list[_TRANSPORT_TYPES] | None = None, + default: MCPFulfillment | None = None, ) -> Self: return cls( params=MCPServiceExtensionParams( @@ -135,15 +140,11 @@ def single_demand( allowed_transports=allowed_transports or _DEFAULT_ALLOWED_TRANSPORTS, ) } - ) + ), + default=MCPServiceExtensionMetadata(mcp_fulfillments={name: default}) if default else None, ) -class MCPServiceExtensionMetadata(pydantic.BaseModel): - mcp_fulfillments: dict[str, MCPFulfillment] = {} - """Provided servers corresponding to the server requests.""" - - class MCPServiceExtensionServer(BaseExtensionServer[MCPServiceExtensionSpec, MCPServiceExtensionMetadata]): @override def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext): @@ -173,7 +174,7 @@ def parse_client_metadata(self, message: A2AMessage) -> MCPServiceExtensionMetad return metadata @asynccontextmanager - async def create_client(self, demand: str = _DEFAULT_DEMAND_NAME): + async def create_client(self, demand: str = DEFAULT_DEMAND_NAME): fulfillment = self.data.mcp_fulfillments.get(demand) if self.data else None if not fulfillment: diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/platform.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/platform.py index e64ea62c..d4a2f450 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/platform.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/platform.py @@ -66,7 +66,7 @@ class PlatformApiExtensionParams(pydantic.BaseModel): auto_use: bool = True -class PlatformApiExtensionSpec(BaseExtensionSpec[PlatformApiExtensionParams]): +class PlatformApiExtensionSpec(BaseExtensionSpec[PlatformApiExtensionParams, PlatformApiExtensionMetadata]): URI: str = "https://a2a-extensions.adk.kagenti.dev/services/platform_api/v1" def __init__(self, params: PlatformApiExtensionParams | None = None) -> None: @@ -152,7 +152,9 @@ class _PlatformSelfRegistrationExtensionParams(pydantic.BaseModel): self_registration_id: str -class _PlatformSelfRegistrationExtensionSpec(BaseExtensionSpec[_PlatformSelfRegistrationExtensionParams]): +class _PlatformSelfRegistrationExtensionSpec( + BaseExtensionSpec[_PlatformSelfRegistrationExtensionParams, _PlatformSelfRegistrationExtension] +): URI: str = "https://a2a-extensions.adk.kagenti.dev/services/platform-self-registration/v1" diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/streaming.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/streaming.py new file mode 100644 index 00000000..2467f7a7 --- /dev/null +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/streaming.py @@ -0,0 +1,335 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + + +from __future__ import annotations + +from collections.abc import AsyncIterator +from enum import StrEnum +from types import NoneType +from typing import Any, cast + +from a2a.client.client import ClientEvent +from a2a.types import ( + AgentExtension, + Message, + TaskArtifactUpdateEvent, + TaskStatus, + TaskStatusUpdateEvent, +) +from google.protobuf.json_format import MessageToDict +from pydantic import BaseModel + +from kagenti_adk.a2a.extensions.base import ( + BaseExtensionClient, + BaseExtensionServer, + NoParamsBaseExtensionSpec, +) +from kagenti_adk.a2a.types import Metadata +from kagenti_adk.types import JsonPatch, JsonValue + + +class StreamingExtensionSpec(NoParamsBaseExtensionSpec[NoneType]): + URI = "https://a2a-extensions.agentstack.beeai.dev/ui/streaming/v1" + DESCRIPTION = "Enables fine-grained streaming of token chunks." + + +class StreamOperations(StrEnum): + MESSAGE_UPDATE = "message_update" + + +class StreamingExtensionServer(BaseExtensionServer[StreamingExtensionSpec, NoneType]): + """ + Adds streaming support to the A2A protocol through TaskStatusUpdateEvent.metadata object. + + Updates are emitted as metadata objects containing JSON Patch (RFC 6902) operations, + extended with `str_ins` from json-crdt-patch for efficient text streaming. + + Supported operations: + - replace: initialize the message draft (root replace with message_id and parts) + - add: adding parts to the message + - str_ins: streaming individual text chunks (=llm tokens) + + The stream is a sequential log of patches, that are applied to a final message: + --- + update: {..., "https://.../streaming": {"message_update": { "op": "replace", "path": "", "value": {"message_id": "...", "parts": [{"text": "Hello "}]} }, "message_id": "..."}} + update: {..., "https://.../streaming": {"message_update": { "op": "str_ins", "path": "/parts/0/text", "pos": 6, "value": "world" }, "message_id": "..."}} + update: {..., "https://.../streaming": {"message_update": { "op": "str_ins", "path": "/parts/0/text", "pos": 11, "value": "!" }, "message_id": "..."}} + """ + + def to_metadata(self, patches: JsonPatch, message_id: str | None = None) -> Metadata: + payload: dict[str, Any] = {StreamOperations.MESSAGE_UPDATE: patches} + if message_id is not None: + payload["message_id"] = message_id + return Metadata({self.spec.URI: cast(JsonValue, payload)}) + + +# --- Client-side delta types --- + + +class TextDelta(BaseModel): + """A text chunk appended to an existing text part.""" + + part_index: int + delta: str + + +class PartDelta(BaseModel): + """A new part was added to the message.""" + + part_index: int + part: dict[str, Any] + + +class MetadataDelta(BaseModel): + """Message metadata was added or updated.""" + + metadata: dict[str, Any] + + +class ArtifactDelta(BaseModel, arbitrary_types_allowed=True): + """An artifact update event.""" + + event: TaskArtifactUpdateEvent + + +class StateChange(BaseModel, arbitrary_types_allowed=True): + """A task state transition (WORKING, COMPLETED, INPUT_REQUIRED, etc.).""" + + state: int # TaskState enum value + message: Message | None = None + + +StreamDelta = TextDelta | PartDelta | MetadataDelta | ArtifactDelta | StateChange + + +class StreamingExtensionClient(BaseExtensionClient[StreamingExtensionSpec, NoneType]): + """Client-side streaming consumer. + + Wraps raw A2A ``ClientEvent`` streams into a unified delta-based API. + Works identically whether the server supports the streaming extension or not: + + - **With streaming**: patches are applied incrementally; full messages whose + ``message_id`` was already streamed are suppressed. + - **Without streaming**: full messages are decomposed into ``PartDelta`` / + ``MetadataDelta`` / ``StateChange`` events so the consumer code is the same. + + Usage:: + + streaming = StreamingExtensionClient(spec) + async for delta, task in streaming.stream(client.send_message(msg)): + match delta: + case TextDelta(delta=text): + print(text, end="", flush=True) + case PartDelta(part=part): + ... + case StateChange(state=TaskState.TASK_STATE_COMPLETED): + print() + case ArtifactDelta(event=evt): + ... + """ + + def __init__(self, spec: StreamingExtensionSpec) -> None: + super().__init__(spec) + self._draft: dict[str, Any] = {} + self._message_id: str | None = None + self._streamed_messages: dict[str, int] = {} # message_id -> parts_count + + @property + def draft(self) -> dict[str, Any]: + """Current draft message built from applied patches.""" + return self._draft + + @property + def message_id(self) -> str | None: + """Current message_id being built from patches.""" + return self._message_id + + def to_agent_card_extensions(self, **kwargs) -> list[AgentExtension]: + return self.spec.to_agent_card_extensions(required=False) + + async def stream( + self, + events: AsyncIterator[ClientEvent], + ) -> AsyncIterator[tuple[StreamDelta, Any | None]]: + """Consume a raw A2A event stream and yield ``(delta, task)`` pairs. + + The method handles reconciliation automatically: messages that were + already streamed via patches are suppressed, and merged messages + (draft + explicit) only emit the new parts beyond the streamed prefix. + """ + async for response, task in events: + if response.HasField("artifact_update"): + yield ArtifactDelta(event=response.artifact_update), task + continue + + if not response.HasField("status_update"): + continue + + update: TaskStatusUpdateEvent = response.status_update + patch_data = self._extract_patches(update) + + if patch_data is not None: + # Streaming mode: apply patch and emit deltas + for delta in self._apply_and_emit(patch_data): + yield delta, task + continue + + # Non-streaming event (full message or state change) + status: TaskStatus = update.status + message: Message | None = status.message if status.HasField("message") else None + + if message and message.message_id and message.message_id in self._streamed_messages: + # This message was already streamed via patches + streamed_count = self._streamed_messages[message.message_id] + parts_list = list(message.parts) + + if len(parts_list) > streamed_count: + # Merged message: emit only the new parts beyond the streamed prefix + for i, part in enumerate(parts_list[streamed_count:], start=streamed_count): + yield PartDelta(part_index=i, part=MessageToDict(part)), task + + # Emit state change with the full message for reference + if MessageToDict(status): + yield StateChange(state=status.state, message=message), task + + # Clean up tracking for this message + del self._streamed_messages[message.message_id] + self._draft = {} + self._message_id = None + continue + + # Non-streamed message: decompose into deltas + if message and message.message_id: + for i, part in enumerate(message.parts): + yield PartDelta(part_index=i, part=MessageToDict(part)), task + meta = MessageToDict(message.metadata) + if meta: + yield MetadataDelta(metadata=meta), task + + if MessageToDict(status): + yield StateChange(state=status.state, message=message), task + + def text_delta(self, event: TaskStatusUpdateEvent) -> str | None: + """Extract text delta from streaming patch in event metadata. + + Returns the text content of the first text-producing patch, or ``None`` + if the event does not contain a text-producing streaming patch. + """ + patches = self._extract_patches(event) + if patches is None: + return None + for patch in patches: + if (text := self._patch_text_delta(patch)) is not None: + return text + return None + + def apply_patch(self, event: TaskStatusUpdateEvent) -> dict[str, Any] | None: + """Apply streaming patches to internal draft. Returns current draft state, or ``None`` if no patches.""" + patches = self._extract_patches(event) + if patches is None: + return None + self._apply_patches_to_draft(patches) + return self._draft + + def _extract_patches(self, event: TaskStatusUpdateEvent) -> list[dict[str, Any]] | None: + """Extract the streaming patch list from event metadata, if present.""" + if not event.HasField("metadata"): + return None + meta = MessageToDict(event.metadata) + if not meta or self.spec.URI not in meta: + return None + ext_data = meta[self.spec.URI] + if not isinstance(ext_data, dict): + return None + patches = ext_data.get(StreamOperations.MESSAGE_UPDATE) + if not isinstance(patches, list): + return None + + # Track message_id from extension metadata + msg_id = ext_data.get("message_id") + if msg_id and isinstance(msg_id, str): + self._message_id = msg_id + + return patches + + def _apply_patches_to_draft(self, patches: list[dict[str, Any]]) -> None: + """Apply a list of patch operations to the internal draft.""" + from kagenti_adk.server.jsonpatch_ext import ExtendedJsonPatch + + self._draft = ExtendedJsonPatch(patches).apply(self._draft) + # Update tracking + if self._message_id: + parts = self._draft.get("parts", []) + self._streamed_messages[self._message_id] = len(parts) + + def _apply_and_emit(self, patches: list[dict[str, Any]]) -> list[StreamDelta]: + """Apply patches and return the corresponding deltas.""" + from kagenti_adk.server.jsonpatch_ext import ExtendedJsonPatch + + deltas: list[StreamDelta] = [] + metadata_ops: list[dict[str, Any]] = [] + parts_before = len(self._draft.get("parts", [])) + + self._apply_patches_to_draft(patches) + + add_parts_seen = 0 + for patch in patches: + op = patch.get("op") + path = patch.get("path", "") + value = patch.get("value") + + if op == "str_ins": + segments = path.split("/") + if len(segments) >= 3 and segments[1] == "parts": + part_index = int(segments[2]) + deltas.append(TextDelta(part_index=part_index, delta=patch.get("value", ""))) + + elif op == "replace" and path == "": + if isinstance(value, dict): + for i, part in enumerate(value.get("parts", [])): + deltas.append(PartDelta(part_index=i, part=part)) + if meta := value.get("metadata"): + metadata_ops.append({"op": "replace", "path": "", "value": meta}) + + elif op == "add" and path == "/parts/-": + part_index = parts_before + add_parts_seen + add_parts_seen += 1 + if isinstance(value, dict): + deltas.append(PartDelta(part_index=part_index, part=value)) + + elif path.startswith("/metadata"): + # Strip /metadata prefix and collect for incremental application + metadata_ops.append({**patch, "path": path[len("/metadata"):]}) + + if metadata_ops: + incremental = ExtendedJsonPatch(metadata_ops).apply({}) + if incremental: + deltas.append(MetadataDelta(metadata=incremental)) + + return deltas + + @staticmethod + def _patch_text_delta(patch: dict[str, Any]) -> str | None: + """Extract a text delta from a single patch operation.""" + op = patch.get("op") + value = patch.get("value") + path = patch.get("path", "") + + if op == "str_ins": + return value if isinstance(value, str) else None + + if op == "replace" and path == "": + # Root replace -- extract text from first part if any + if isinstance(value, dict): + parts = value.get("parts", []) + if parts and isinstance(parts[0], dict) and "text" in parts[0]: + return parts[0]["text"] + return None + + if op == "add" and "/parts/" in path: + if isinstance(value, dict) and "text" in value: + return value["text"] + return None + + return None diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/tools/call.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/tools/call.py index d0bfdc44..7364ae53 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/tools/call.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/tools/call.py @@ -81,14 +81,14 @@ class ToolCallExtensionParams(BaseModel): pass -class ToolCallExtensionSpec(BaseExtensionSpec[ToolCallExtensionParams]): - URI: str = "https://a2a-extensions.adk.kagenti.dev/tools/call/v1" - - class ToolCallExtensionMetadata(BaseModel): pass +class ToolCallExtensionSpec(BaseExtensionSpec[ToolCallExtensionParams, ToolCallExtensionMetadata]): + URI: str = "https://a2a-extensions.adk.kagenti.dev/tools/call/v1" + + class ToolCallExtensionServer(BaseExtensionServer[ToolCallExtensionSpec, ToolCallExtensionMetadata]): def create_request_message(self, *, request: ToolCallRequest): return AgentMessage( diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/__init__.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/__init__.py index 85e56c35..907bfe67 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/__init__.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/__init__.py @@ -2,11 +2,114 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +from .agent_detail import ( + AgentDetail, + AgentDetailContributor, + AgentDetailExtensionClient, + AgentDetailExtensionServer, + AgentDetailExtensionSpec, + AgentDetailTool, + EnvVar, +) +from .canvas import ( + CanvasEditRequest, + CanvasEditRequestMetadata, + CanvasExtensionServer, + CanvasExtensionSpec, +) +from .citation import ( + Citation, + CitationExtensionClient, + CitationExtensionServer, + CitationExtensionSpec, + CitationMetadata, +) +from .error import ( + DEFAULT_ERROR_EXTENSION, + Error, + ErrorContext, + ErrorExtensionClient, + ErrorExtensionParams, + ErrorExtensionServer, + ErrorExtensionSpec, + ErrorGroup, + ErrorMetadata, + get_error_extension_context, + use_error_extension_context, +) +from .form_request import ( + FormRequestExtensionClient, + FormRequestExtensionServer, + FormRequestExtensionSpec, +) +from .settings import ( + AgentRunSettings, + CheckboxField, + CheckboxFieldValue, + CheckboxGroupField, + CheckboxGroupFieldValue, + OptionItem, + SettingsExtensionClient, + SettingsExtensionServer, + SettingsExtensionSpec, + SettingsFieldValue, + SettingsRender, + SingleSelectField, + SingleSelectFieldValue, +) +from .trajectory import ( + Trajectory, + TrajectoryExtensionClient, + TrajectoryExtensionServer, + TrajectoryExtensionSpec, +) -from .agent_detail import * -from .canvas import * -from .citation import * -from .error import * -from .form_request import * -from .settings import * -from .trajectory import * +__all__ = [ + "AgentDetail", + "AgentDetailContributor", + "AgentDetailExtensionClient", + "AgentDetailExtensionServer", + "AgentDetailExtensionSpec", + "AgentDetailTool", + "AgentRunSettings", + "CanvasEditRequest", + "CanvasEditRequestMetadata", + "CanvasExtensionServer", + "CanvasExtensionSpec", + "CheckboxField", + "CheckboxFieldValue", + "CheckboxGroupField", + "CheckboxGroupFieldValue", + "Citation", + "CitationExtensionClient", + "CitationExtensionServer", + "CitationExtensionSpec", + "CitationMetadata", + "DEFAULT_ERROR_EXTENSION", + "EnvVar", + "Error", + "ErrorContext", + "ErrorExtensionClient", + "ErrorExtensionParams", + "ErrorExtensionServer", + "ErrorExtensionSpec", + "ErrorGroup", + "ErrorMetadata", + "FormRequestExtensionClient", + "FormRequestExtensionServer", + "FormRequestExtensionSpec", + "OptionItem", + "SettingsExtensionClient", + "SettingsExtensionServer", + "SettingsExtensionSpec", + "SettingsFieldValue", + "SettingsRender", + "SingleSelectField", + "SingleSelectFieldValue", + "Trajectory", + "TrajectoryExtensionClient", + "TrajectoryExtensionServer", + "TrajectoryExtensionSpec", + "get_error_extension_context", + "use_error_extension_context", +] diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/agent_detail.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/agent_detail.py index f867f693..2170a910 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/agent_detail.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/agent_detail.py @@ -44,7 +44,7 @@ class AgentDetail(pydantic.BaseModel, extra="allow"): variables: list[EnvVar] | None = None -class AgentDetailExtensionSpec(BaseExtensionSpec[AgentDetail]): +class AgentDetailExtensionSpec(BaseExtensionSpec[AgentDetail, NoneType]): URI: str = "https://a2a-extensions.adk.kagenti.dev/ui/agent-detail/v1" diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/canvas.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/canvas.py index 2d274743..a9371387 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/canvas.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/canvas.py @@ -57,7 +57,7 @@ def parse_artifact(cls, v): return v -class CanvasExtensionSpec(NoParamsBaseExtensionSpec): +class CanvasExtensionSpec(NoParamsBaseExtensionSpec[CanvasEditRequestMetadata]): URI: str = "https://a2a-extensions.adk.kagenti.dev/ui/canvas/v1" diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/citation.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/citation.py index 131d41dd..8a9d7eb0 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/citation.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/citation.py @@ -54,13 +54,13 @@ class CitationMetadata(pydantic.BaseModel): citations: list[Citation] = pydantic.Field(default_factory=list) -class CitationExtensionSpec(NoParamsBaseExtensionSpec): +class CitationExtensionSpec(NoParamsBaseExtensionSpec[NoneType]): URI: str = "https://a2a-extensions.adk.kagenti.dev/ui/citation/v1" class CitationExtensionServer(BaseExtensionServer[CitationExtensionSpec, NoneType]): def citation_metadata(self, *, citations: list[Citation]) -> Metadata: - return Metadata({self.spec.URI: CitationMetadata(citations=citations).model_dump(mode="json")}) + return Metadata({self.spec.URI: [c.model_dump(mode="json") for c in citations]}) def message( self, @@ -76,4 +76,4 @@ def message( ) -class CitationExtensionClient(BaseExtensionClient[CitationExtensionSpec, CitationMetadata]): ... +class CitationExtensionClient(BaseExtensionClient[CitationExtensionSpec, list[Citation]]): ... diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/error.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/error.py index fbfbfdfb..3284ba26 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/error.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/error.py @@ -85,7 +85,7 @@ class ErrorExtensionParams(pydantic.BaseModel): include_stacktrace: bool = False -class ErrorExtensionSpec(BaseExtensionSpec[ErrorExtensionParams]): +class ErrorExtensionSpec(BaseExtensionSpec[ErrorExtensionParams, NoneType]): URI: str = "https://a2a-extensions.adk.kagenti.dev/ui/error/v1" diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/form_request.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/form_request.py index 260a7182..b175cf71 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/form_request.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/form_request.py @@ -37,7 +37,7 @@ T = TypeVar("T") -class FormRequestExtensionSpec(NoParamsBaseExtensionSpec): +class FormRequestExtensionSpec(NoParamsBaseExtensionSpec[FormResponse]): URI: str = "https://a2a-extensions.adk.kagenti.dev/ui/form_request/v1" diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/settings.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/settings.py index 9de81a45..538fd066 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/settings.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/settings.py @@ -72,7 +72,7 @@ class AgentRunSettings(BaseModel): @deprecated("Use FormServiceExtensionSpec.demand_settings() instead") -class SettingsExtensionSpec(BaseExtensionSpec[SettingsRender | None]): +class SettingsExtensionSpec(BaseExtensionSpec[SettingsRender | None, AgentRunSettings]): URI: str = "https://a2a-extensions.adk.kagenti.dev/ui/settings/v1" diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/trajectory.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/trajectory.py index ac4d6eaf..2a06de95 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/trajectory.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/trajectory.py @@ -42,7 +42,7 @@ class Trajectory(pydantic.BaseModel): group_id: str | None = None -class TrajectoryExtensionSpec(NoParamsBaseExtensionSpec): +class TrajectoryExtensionSpec(NoParamsBaseExtensionSpec[NoneType]): URI: str = "https://a2a-extensions.adk.kagenti.dev/ui/trajectory/v1" @@ -51,7 +51,11 @@ def trajectory_metadata( self, *, title: str | None = None, content: str | None = None, group_id: str | None = None ) -> Metadata: return Metadata( - {self.spec.URI: Trajectory(title=title, content=content, group_id=group_id).model_dump(mode="json")} + { + self.spec.URI: [ + Trajectory(title=title, content=content, group_id=group_id).model_dump(mode="json"), + ] + } ) def message( diff --git a/apps/adk-py/src/kagenti_adk/server/accumulator.py b/apps/adk-py/src/kagenti_adk/server/accumulator.py new file mode 100644 index 00000000..349afdb5 --- /dev/null +++ b/apps/adk-py/src/kagenti_adk/server/accumulator.py @@ -0,0 +1,219 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import uuid +from dataclasses import dataclass +from typing import Self + +from a2a.types import ( + Message, + Part, + Role, + TaskStatus, + TaskStatusUpdateEvent, +) +from pydantic import BaseModel, Field + +from kagenti_adk.a2a.types import Metadata, RunYield +from kagenti_adk.types import JsonPatch, JsonPatchOp + + +class TextPartContext(BaseModel, arbitrary_types_allowed=True): + chunks: list[str] = Field(default_factory=list) + message_context: MessageContext + part_index: int + pos: int = 0 + + def add_chunk(self, chunk: str) -> JsonPatch: + from google.protobuf.json_format import MessageToDict + + self.chunks.append(chunk) + part_dict = MessageToDict(Part(text=chunk)) + if self.pos == 0: + if not self.message_context.initialized: + self.message_context.initialized = True + msg_id = str(self.message_context.message_id) + patch: JsonPatchOp = { + "op": "replace", + "path": "", + "value": {"message_id": msg_id, "parts": [part_dict]}, + } + else: + patch = {"op": "add", "path": "/parts/-", "value": part_dict} + else: + patch = {"op": "str_ins", "pos": self.pos, "path": f"/parts/{self.part_index}/text", "value": chunk} + self.pos += len(chunk) + return [patch] + + def build(self) -> Part: + return Part(text="".join(self.chunks)) + + +class MessageContext(BaseModel, arbitrary_types_allowed=True): + parts: list[Part] = Field(default_factory=list) + metadata: Metadata | None = None + initialized: bool = False + message_id: uuid.UUID = Field(default_factory=uuid.uuid4) + + def add_metadata(self, metadata: Metadata) -> JsonPatch: + from kagenti_adk.server.jsonpatch_ext import make_patch + from kagenti_adk.server.utils import merge_metadata + + if self.metadata is None: + self.metadata = Metadata(metadata) + if not self.initialized: + self.initialized = True + return [ + { + "op": "replace", + "path": "", + "value": {"message_id": str(self.message_id), "parts": [], "metadata": dict(self.metadata)}, + } + ] + return [{"op": "add", "path": "/metadata", "value": dict(self.metadata)}] + old_metadata = dict(self.metadata) + self.metadata = merge_metadata(self.metadata, metadata) + new_metadata = dict(self.metadata) + + ops: JsonPatch = [] + for op in make_patch(old_metadata, new_metadata).patch: + rewritten: JsonPatchOp = {"op": op["op"], "path": f"/metadata{op['path']}"} + if "value" in op: + rewritten["value"] = op["value"] + if "pos" in op: + rewritten["pos"] = op["pos"] + ops.append(rewritten) + return ops + + def add_part(self, part: Part) -> JsonPatch: + from google.protobuf.json_format import MessageToDict + + self.parts.append(part) + part_dict = MessageToDict(part) + if not self.initialized: + self.initialized = True + return [{"op": "replace", "path": "", "value": {"message_id": str(self.message_id), "parts": [part_dict]}}] + return [{"op": "add", "path": "/parts/-", "value": part_dict}] + + def build(self) -> Message: + m = Message(message_id=str(self.message_id), role=Role.ROLE_AGENT) + m.parts.extend(self.parts) + if self.metadata: + for k, v in self.metadata.items(): + m.metadata[k] = v + return m + + +@dataclass +class ProcessResult: + """Result of processing a yield through the accumulator.""" + + accumulated: bool = False + """True if the value was consumed by the accumulator (str, Part, Metadata).""" + + draft: Message | None = None + """Flushed message from accumulated state, when a non-accumulating yield triggers a flush.""" + + patch: JsonPatch | None = None + """Streaming patches to send as a partial update (JSON Patch operations).""" + + message_id: str | None = None + """The message_id of the current accumulation cycle, for client-side correlation.""" + + +class MessageAccumulator: + """Manages the streaming accumulation state machine. + + Accumulates string chunks, Parts, and Metadata into messages, + flushing when non-accumulating yields (Message, TaskStatus, etc.) arrive. + + The state machine has 3 levels: + - Base level (Self): no accumulation in progress + - MessageContext: accumulating parts and metadata into a message + - TextPartContext: accumulating string chunks into a single text Part + """ + + def __init__(self) -> None: + self._active: Self | MessageContext | TextPartContext = self + + @property + def active_context(self) -> MessageAccumulator | MessageContext | TextPartContext: + return self._active + + def process(self, value: RunYield) -> ProcessResult: + """Process a yield value through the state machine. + + Returns a ProcessResult describing what happened: + - accumulated=True: the value was consumed. patch may contain a streaming update. + - accumulated=False: the value is a "control" yield (Message, TaskStatus, etc.) + that the caller should handle. draft may contain a flushed accumulated message. + """ + match self._active: + case MessageAccumulator(): + return self._process_at_base_level(value) + case MessageContext() as ctx: + return self._process_at_message_level(ctx, value) + case TextPartContext() as ctx: + return self._process_at_text_part_level(ctx, value) + + def flush(self) -> Message | None: + """Flush any accumulated state into a message. Resets to base level.""" + match self._active: + case TextPartContext() as ctx: + ctx.message_context.add_part(ctx.build()) + msg = ctx.message_context.build() + case MessageContext() as ctx: + msg = ctx.build() + case _: + return None + self._active = self + return msg + + def _process_at_base_level(self, value: RunYield) -> ProcessResult: + match value: + case Message() | TaskStatus() | TaskStatusUpdateEvent(): + return ProcessResult(accumulated=False) + case _: + self._active = MessageContext() + return self._process_at_message_level(self._active, value) + + def _process_at_message_level(self, ctx: MessageContext, value: RunYield) -> ProcessResult: + msg_id = str(ctx.message_id) + match value: + case Part() as part: + patch = ctx.add_part(part) + return ProcessResult(accumulated=True, patch=patch, message_id=msg_id) + case Metadata() as metadata: + patch = ctx.add_metadata(metadata) + return ProcessResult(accumulated=True, patch=patch, message_id=msg_id) + case dict() as data: + patch = ctx.add_part(self._dict_to_part(data)) + return ProcessResult(accumulated=True, patch=patch, message_id=msg_id) + case str(): + self._active = TextPartContext(message_context=ctx, part_index=len(ctx.parts)) + return self._process_at_text_part_level(self._active, value) + case _: + draft = ctx.build() + self._active = self + return ProcessResult(accumulated=False, draft=draft) + + @staticmethod + def _dict_to_part(data: dict) -> Part: + from google.protobuf.struct_pb2 import Struct, Value + + s = Struct() + s.update(data) + return Part(data=Value(struct_value=s)) + + def _process_at_text_part_level(self, ctx: TextPartContext, value: RunYield) -> ProcessResult: + msg_id = str(ctx.message_context.message_id) + match value: + case str(text): + patch = ctx.add_chunk(text) + return ProcessResult(accumulated=True, patch=patch, message_id=msg_id) + case _: + ctx.message_context.add_part(ctx.build()) + self._active = ctx.message_context + return self._process_at_message_level(ctx.message_context, value) diff --git a/apps/adk-py/src/kagenti_adk/server/agent.py b/apps/adk-py/src/kagenti_adk/server/agent.py index 94a03f9d..535066e2 100644 --- a/apps/adk-py/src/kagenti_adk/server/agent.py +++ b/apps/adk-py/src/kagenti_adk/server/agent.py @@ -35,7 +35,8 @@ from google.protobuf import message as _message from typing_extensions import override -from kagenti_adk.a2a.extensions import AgentDetailExtensionSpec, BaseExtensionServer +from kagenti_adk.a2a.extensions import BaseExtensionServer +from kagenti_adk.a2a.extensions.streaming import StreamingExtensionServer from kagenti_adk.a2a.extensions.ui.agent_detail import ( AgentDetail, AgentDetailExtensionSpec, @@ -45,12 +46,14 @@ get_error_extension_context, ) from kagenti_adk.a2a.types import Metadata, RunYield, RunYieldResume, validate_message +from kagenti_adk.server.accumulator import MessageAccumulator from kagenti_adk.server.constants import _DEFAULT_AGENT_INTERFACE, _DEFAULT_AGENT_SKILL, DEFAULT_IMPLICIT_EXTENSIONS from kagenti_adk.server.context import RunContext from kagenti_adk.server.dependencies import Dependency, Depends, extract_dependencies +from kagenti_adk.server.exceptions import InvalidYieldError from kagenti_adk.server.store.context_store import ContextStore -from kagenti_adk.server.utils import cancel_task -from kagenti_adk.types import A2ASecurity +from kagenti_adk.server.utils import cancel_task, merge_messages +from kagenti_adk.types import A2ASecurity, JsonPatch from kagenti_adk.util.logging import logger AgentFunction: TypeAlias = Callable[[], AsyncGenerator[RunYield, RunYieldResume]] @@ -285,15 +288,17 @@ def decorator(fn: OriginalFnType) -> Agent: if inspect.isasyncgenfunction(fn): async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None: + gen: AsyncGenerator[RunYield, RunYieldResume] = fn(*args, **kwargs) try: - gen: AsyncGenerator[RunYield, RunYieldResume] = fn(*args, **kwargs) value: RunYieldResume = None while True: - value = await _ctx.yield_async(await gen.asend(value)) + result = await gen.asend(value) + try: + value = await _ctx.yield_async(result) + except Exception as e: + value = await _ctx.yield_async(await gen.athrow(e)) except StopAsyncIteration: pass - except Exception as e: - await _ctx.yield_async(e) finally: _ctx.shutdown() @@ -301,24 +306,26 @@ async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None: async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None: try: - await _ctx.yield_async(await fn(*args, **kwargs)) - except Exception as e: - await _ctx.yield_async(e) + result = await fn(*args, **kwargs) + if result is not None: + await _ctx.yield_async(result) finally: _ctx.shutdown() elif inspect.isgeneratorfunction(fn): def _execute_fn_sync(_ctx: RunContext, *args, **kwargs) -> None: + gen: Generator[RunYield, RunYieldResume] = fn(*args, **kwargs) try: - gen: Generator[RunYield, RunYieldResume] = fn(*args, **kwargs) - value = None + value: RunYieldResume = None while True: - value = _ctx.yield_sync(gen.send(value)) + result = gen.send(value) + try: + value = _ctx.yield_sync(result) + except Exception as e: + value = _ctx.yield_sync(gen.throw(e)) except StopIteration: pass - except Exception as e: - _ctx.yield_sync(e) finally: _ctx.shutdown() @@ -329,9 +336,9 @@ async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None: def _execute_fn_sync(_ctx: RunContext, *args, **kwargs) -> None: try: - _ctx.yield_sync(fn(*args, **kwargs)) - except Exception as e: - _ctx.yield_sync(e) + result = fn(*args, **kwargs) + if result is not None: + _ctx.yield_sync(result) finally: _ctx.shutdown() @@ -362,6 +369,7 @@ def __init__(self, agent: Agent, context_store: ContextStore, on_finish: Callabl self._on_finish: Callable[[], None] | None = on_finish self._working: bool = False self._dependency_container: ActiveDependenciesContainer | None = None + self._accumulator: MessageAccumulator = MessageAccumulator() @property def run_context(self) -> RunContext: @@ -440,16 +448,125 @@ async def cancel(self, request_context: RequestContext, event_queue: EventQueue) finally: await cancel_task(self._task) - def _with_context(self, message: Message | None = None) -> Message | None: - if message: - message.context_id = self.task_updater.context_id - message.task_id = self.task_updater.task_id - return message + def _prepare_message(self, message: Message | None = None, msg_draft: Message | None = None) -> Message | None: + for msg in (message, msg_draft): + if msg: + msg.context_id = self.task_updater.context_id + msg.task_id = self.task_updater.task_id + msgs = [m for m in (msg_draft, message) if m] + return merge_messages(*msgs) if msgs else None + + async def _send_partial_update(self, patches: JsonPatch, message_id: str | None = None): + if not (ext := StreamingExtensionServer.current()): + return + await self.task_updater.update_status( + state=TaskState.TASK_STATE_WORKING, metadata=ext.to_metadata(patches, message_id=message_id) + ) - async def _run_agent_function(self, initial_message: Message) -> None: + async def _handle_message_yield(self, yielded_value: RunYield) -> RunYieldResume: + result = self._accumulator.process(yielded_value) + if result.accumulated: + if result.patch: + await self._send_partial_update(result.patch, message_id=result.message_id) + return None + return await self._dispatch_control_yield(yielded_value, result.draft) + + async def _dispatch_control_yield( + self, yielded_value: RunYield, draft: Message | None = None + ) -> RunYieldResume: + match yielded_value: + case Message() as message: + await self.task_updater.update_status( + TaskState.TASK_STATE_WORKING, + message=self._prepare_message(message, draft), + ) + case TaskStatus( + state=(TaskState.TASK_STATE_AUTH_REQUIRED | TaskState.TASK_STATE_INPUT_REQUIRED) as state, + message=message, + ): + await self.task_updater.update_status( + state=state, + message=self._prepare_message(message, draft), + ) + self._working = False + resume_value = await self.resume_queue.get() + self.resume_queue.task_done() + return resume_value + case TaskStatus(state=state, message=message): + await self.task_updater.update_status( + state=state, + message=self._prepare_message(message, draft), + ) + case TaskStatusUpdateEvent( + status=TaskStatus(state=state, message=message), + metadata=metadata, + ): + await self.task_updater.update_status( + state=state, + message=self._prepare_message(message, draft), + metadata=dict(metadata), + ) + + async def _agent_loop(self, task: asyncio.Task): yield_queue = self.run_context._yield_queue yield_resume_queue = self.run_context._yield_resume_queue + resume_value: RunYieldResume | Exception = None + opened_artifacts: set[str] = set() + + while not task.done() or yield_queue.async_q.qsize() > 0: + yielded_value = await yield_queue.async_q.get() + resume_value = None + self.last_invocation = datetime.now() + + if isinstance(yielded_value, _message.Message): + validate_message(yielded_value) + + try: + match yielded_value: + case Artifact(parts=parts, artifact_id=artifact_id, name=name, metadata=metadata): + last_chunk = True + if "_last_chunk" in metadata: + last_chunk = bool(metadata["_last_chunk"]) + del metadata["_last_chunk"] + append = artifact_id in opened_artifacts + if not last_chunk: + opened_artifacts.add(artifact_id) + elif artifact_id in opened_artifacts: + opened_artifacts.remove(artifact_id) + + await self.task_updater.add_artifact( + parts=list(parts), + artifact_id=artifact_id, + name=name, + metadata=dict(metadata), + last_chunk=last_chunk, + append=append, + ) + + case TaskArtifactUpdateEvent( + artifact=Artifact(artifact_id=artifact_id, name=name, metadata=metadata, parts=parts), + append=append, + last_chunk=last_chunk, + ): + await self.task_updater.add_artifact( + parts=list(parts), + artifact_id=artifact_id, + name=name, + metadata=dict(metadata), + append=append, + last_chunk=last_chunk, + ) + case Part() | dict() | Metadata() | str() | TaskStatus() | TaskStatusUpdateEvent() | Message(): + resume_value = await self._handle_message_yield(yielded_value) + case _: + raise InvalidYieldError(yielded_value) + except Exception as e: + resume_value = e + await yield_resume_queue.async_q.put(resume_value) + + async def _run_agent_function(self, initial_message: Message) -> None: + task: asyncio.Task | None = None try: async with self._agent.dependency_container( initial_message, self.run_context, self.request_context @@ -459,118 +576,22 @@ async def _run_agent_function(self, initial_message: Message) -> None: self._agent.execute_fn(self.run_context, **dependency_container.user_dependency_args) ) try: - resume_value: RunYieldResume = None - opened_artifacts: set[str] = set() - while not task.done() or yield_queue.async_q.qsize() > 0: - yielded_value = await yield_queue.async_q.get() - - if isinstance(yielded_value, _message.Message): - validate_message(yielded_value) - - self.last_invocation = datetime.now() - - match yielded_value: - case str(text): - await self.task_updater.update_status( - TaskState.TASK_STATE_WORKING, - message=self.task_updater.new_agent_message(parts=[Part(text=text)]), - ) - case Part() as part: - await self.task_updater.update_status( - TaskState.TASK_STATE_WORKING, - message=self.task_updater.new_agent_message(parts=[part]), - ) - case Message() as message: - await self.task_updater.update_status( - TaskState.TASK_STATE_WORKING, message=self._with_context(message) - ) - case Artifact(parts=parts, artifact_id=artifact_id, name=name, metadata=metadata): - last_chunk = True - if "_last_chunk" in metadata: - last_chunk = bool(metadata["_last_chunk"]) - del metadata["_last_chunk"] - append = artifact_id in opened_artifacts - if not last_chunk: - opened_artifacts.add(artifact_id) - elif artifact_id in opened_artifacts: - opened_artifacts.remove(artifact_id) - - await self.task_updater.add_artifact( - parts=list(parts), - artifact_id=artifact_id, - name=name, - metadata=dict(metadata), - last_chunk=last_chunk, - append=append, - ) - case TaskStatus( - state=( - TaskState.TASK_STATE_AUTH_REQUIRED | TaskState.TASK_STATE_INPUT_REQUIRED - ) as state, - message=message, - ): - await self.task_updater.update_status(state=state, message=self._with_context(message)) - self._working = False - resume_value = await self.resume_queue.get() - self.resume_queue.task_done() - case TaskStatus(state=state, message=message): - await self.task_updater.update_status(state=state, message=self._with_context(message)) - case TaskStatusUpdateEvent( - status=TaskStatus(state=state, message=message), - metadata=metadata, - ): - await self.task_updater.update_status( - state=state, message=self._with_context(message), metadata=dict(metadata) - ) - case TaskArtifactUpdateEvent( - artifact=Artifact(artifact_id=artifact_id, name=name, metadata=metadata, parts=parts), - append=append, - last_chunk=last_chunk, - ): - await self.task_updater.add_artifact( - parts=list(parts), - artifact_id=artifact_id, - name=name, - metadata=dict(metadata), - append=append, - last_chunk=last_chunk, - ) - case Metadata() as metadata: - await self.task_updater.update_status( - state=TaskState.TASK_STATE_WORKING, - message=self.task_updater.new_agent_message(parts=[], metadata=metadata), - ) - case dict() as data: - from google.protobuf.struct_pb2 import Struct, Value - - s = Struct() - s.update(data) - await self.task_updater.update_status( - state=TaskState.TASK_STATE_WORKING, - message=self.task_updater.new_agent_message( - parts=[Part(data=Value(struct_value=s))] - ), - ) - case Exception() as ex: - raise ex - case _: - raise ValueError(f"Invalid value yielded from agent: {type(yielded_value)}") - - await yield_resume_queue.async_q.put(resume_value) - - await self.task_updater.complete() - - except (janus.AsyncQueueShutDown, GeneratorExit): - await self.task_updater.complete() + with suppress(janus.AsyncQueueShutDown, GeneratorExit): + await self._agent_loop(task) + await task + final_message = self._accumulator.flush() + await self.task_updater.complete(message=self._prepare_message(final_message)) except Exception as ex: logger.error("Error when executing agent", exc_info=ex) await self.task_updater.failed(get_error_extension_context().server.message(ex)) - await cancel_task(task) except Exception as ex: logger.error("Error when executing agent", exc_info=ex) await self.task_updater.failed(get_error_extension_context().server.message(ex)) finally: self._working = False + if task: + with suppress(Exception): + await cancel_task(task) with suppress(Exception): self._handle_finish() diff --git a/apps/adk-py/src/kagenti_adk/server/constants.py b/apps/adk-py/src/kagenti_adk/server/constants.py index 0fec9ad9..b4fd32e5 100644 --- a/apps/adk-py/src/kagenti_adk/server/constants.py +++ b/apps/adk-py/src/kagenti_adk/server/constants.py @@ -8,9 +8,11 @@ from kagenti_adk.a2a.extensions import BaseExtensionServer from kagenti_adk.a2a.extensions.services.platform import PlatformApiExtensionServer, PlatformApiExtensionSpec +from kagenti_adk.a2a.extensions.streaming import StreamingExtensionServer, StreamingExtensionSpec from kagenti_adk.a2a.extensions.ui.error import ErrorExtensionParams, ErrorExtensionServer, ErrorExtensionSpec DEFAULT_IMPLICIT_EXTENSIONS: Final[dict[str, BaseExtensionServer]] = { + StreamingExtensionSpec.URI: StreamingExtensionServer(StreamingExtensionSpec()), ErrorExtensionSpec.URI: ErrorExtensionServer(ErrorExtensionSpec(ErrorExtensionParams())), PlatformApiExtensionSpec.URI: PlatformApiExtensionServer(PlatformApiExtensionSpec()), } diff --git a/apps/adk-py/src/kagenti_adk/server/context.py b/apps/adk-py/src/kagenti_adk/server/context.py index cb20cdb4..01e85eb2 100644 --- a/apps/adk-py/src/kagenti_adk/server/context.py +++ b/apps/adk-py/src/kagenti_adk/server/context.py @@ -8,7 +8,12 @@ from uuid import UUID import janus -from a2a.types import Artifact, Message, Task +from a2a.types import ( + Artifact, + Message, + Task, +) +from asgiref.sync import async_to_sync from pydantic import BaseModel, PrivateAttr from kagenti_adk.a2a.types import RunYield, RunYieldResume @@ -16,22 +21,26 @@ from kagenti_adk.server.store.context_store import ContextStoreInstance +class RunContextSettings(BaseModel): + strict: bool = False + + class RunContext(BaseModel, arbitrary_types_allowed=True): task_id: str context_id: str current_task: Task | None = None related_tasks: list[Task] | None = None + strict: bool = False # TODO: explain strict mode - what yields will stop message etc. Use in match/case _store: ContextStoreInstance _yield_queue: janus.Queue[RunYield] = PrivateAttr(default_factory=janus.Queue) + _yield_resume_queue: janus.Queue[RunYieldResume | Exception] = PrivateAttr(default_factory=janus.Queue) def __init__(self, _store: ContextStoreInstance, **data): super().__init__(**data) self._store = _store - _yield_resume_queue: janus.Queue[RunYieldResume] = PrivateAttr(default_factory=janus.Queue) - - async def store(self, data: Message | Artifact): + def _prepare_store_data(self, data: Message | Artifact) -> Message | Artifact: if not self._store: raise RuntimeError("Context store is not initialized") if isinstance(data, Message): @@ -39,8 +48,14 @@ async def store(self, data: Message | Artifact): msg.CopyFrom(data) msg.context_id = self.context_id msg.task_id = self.task_id - data = msg - await self._store.store(data) + return msg + return data + + async def store(self, data: Message | Artifact): + await self._store.store(self._prepare_store_data(data)) + + def store_sync(self, data: Message | Artifact): + async_to_sync(self._store.store)(self._prepare_store_data(data)) @overload def load_history(self, load_history_items: Literal[False] = False) -> AsyncGenerator[Message | Artifact, None]: ... @@ -63,11 +78,17 @@ async def delete_history_from_id(self, from_id: UUID) -> None: def yield_sync(self, value: RunYield) -> RunYieldResume: self._yield_queue.sync_q.put(value) - return self._yield_resume_queue.sync_q.get() + resp = self._yield_resume_queue.sync_q.get() + if isinstance(resp, Exception): + raise resp + return resp async def yield_async(self, value: RunYield) -> RunYieldResume: await self._yield_queue.async_q.put(value) - return await self._yield_resume_queue.async_q.get() + resp = await self._yield_resume_queue.async_q.get() + if isinstance(resp, Exception): + raise resp + return resp def shutdown(self) -> None: self._yield_queue.shutdown() diff --git a/apps/adk-py/src/kagenti_adk/server/dependencies.py b/apps/adk-py/src/kagenti_adk/server/dependencies.py index d02a82af..32ece69b 100644 --- a/apps/adk-py/src/kagenti_adk/server/dependencies.py +++ b/apps/adk-py/src/kagenti_adk/server/dependencies.py @@ -17,7 +17,7 @@ from typing_extensions import Doc from kagenti_adk.a2a.extensions.base import BaseExtensionServer, BaseExtensionSpec -from kagenti_adk.server.context import RunContext +from kagenti_adk.server.context import RunContext, RunContextSettings Dependency: TypeAlias = Callable[[Message, RunContext, RequestContext], Any] | BaseExtensionServer[Any, Any] @@ -58,9 +58,36 @@ async def lifespan() -> AsyncIterator[Dependency]: return lifespan() +def _get_param_type_hints(fn: Callable[..., Any]) -> dict[str, Any]: + """Get type hints for function parameters only, skipping the return annotation. + + typing.get_type_hints() evaluates all annotations including return type, + which can fail when annotations use `X | Y` with types that don't support + the `|` operator at runtime (e.g. protobuf classes, factory functions). + """ + try: + return typing.get_type_hints(fn, include_extras=True) + except TypeError: + # Evaluate parameter annotations individually, skipping any that fail + globalns = getattr(fn, "__globals__", {}) + hints: dict[str, Any] = {} + for name, param in inspect.signature(fn).parameters.items(): + ann = param.annotation + if ann is inspect.Parameter.empty: + continue + if isinstance(ann, str): + try: + hints[name] = eval(ann, globalns) # noqa: S307 + except Exception: + hints[name] = ann + else: + hints[name] = ann + return hints + + def extract_dependencies(fn: Callable[..., Any]) -> dict[str, Depends]: sign = inspect.signature(fn) - type_hints = typing.get_type_hints(fn, include_extras=True) + type_hints = _get_param_type_hints(fn) dependencies = {} seen_keys = set() @@ -70,6 +97,11 @@ def process_args(name: str, args: tuple[Any, ...]) -> None: # extension_param: Annotated[some_type, Depends(some_callable)] if isinstance(spec, Depends): dependencies[name] = spec + # extension_param: Annotated[RunContext, RunContextSettings()] + if isinstance(dep_type, RunContext) and isinstance(spec, RunContextSettings): + dependencies[name] = Depends( + lambda _message, run_context, _request_context: run_context.model_copy(update=spec.model_dump()) + ) # extension_param: Annotated[BaseExtensionServer, BaseExtensionSpec()] elif ( isclass(dep_type) and issubclass(dep_type, BaseExtensionServer) and isinstance(spec, BaseExtensionSpec) @@ -114,7 +146,7 @@ def process_args(name: str, args: tuple[Any, ...]) -> None: if reserved_names := {param for param in dependencies if param.startswith("__")}: raise TypeError(f"User-defined dependencies cannot start with double underscore: {reserved_names}") - extension_deps = Counter(dep.extension.spec.URI for dep in dependencies.values() if dep.extension) + extension_deps = Counter(dep.extension.spec.URI for dep in dependencies.values() if dep.extension is not None) if duplicate_uris := {k for k, v in extension_deps.items() if v > 1}: raise TypeError(f"Duplicate extension URIs found in the agent function: {duplicate_uris}") diff --git a/apps/adk-py/src/kagenti_adk/server/exceptions.py b/apps/adk-py/src/kagenti_adk/server/exceptions.py index 7200f7b6..68b91907 100644 --- a/apps/adk-py/src/kagenti_adk/server/exceptions.py +++ b/apps/adk-py/src/kagenti_adk/server/exceptions.py @@ -1,5 +1,10 @@ # Copyright 2026 © IBM Corp. # SPDX-License-Identifier: Apache-2.0 - from __future__ import annotations +from kagenti_adk.a2a.types import RunYield + + +class InvalidYieldError(RuntimeError): + def __init__(self, yielded_value: RunYield): + super().__init__(f"Invalid yield of type: {type(yielded_value)}") diff --git a/apps/adk-py/src/kagenti_adk/server/jsonpatch_ext.py b/apps/adk-py/src/kagenti_adk/server/jsonpatch_ext.py new file mode 100644 index 00000000..7b2dca9d --- /dev/null +++ b/apps/adk-py/src/kagenti_adk/server/jsonpatch_ext.py @@ -0,0 +1,160 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import difflib +import json +from collections.abc import MutableMapping, MutableSequence +from typing import Any + +import jsonpatch +from jsonpatch import DiffBuilder, JsonPatch, PatchOperation + + +class StrInsOperation(PatchOperation): + """ + Inserts text into a string property at a specific position. + Operation format: { "op": "str_ins", "path": "/foo/bar", "pos": 5, "value": "text" } + If "pos" is omitted, it defaults to appending. + """ + + def apply(self, obj: Any) -> Any: + try: + value = self.operation["value"] + except KeyError: + raise jsonpatch.InvalidJsonPatch("The operation does not contain a 'value' member") + + subobj, part = self.pointer.to_last(obj) + + if isinstance(subobj, MutableMapping): + if part not in subobj: + raise jsonpatch.JsonPatchConflict(f"Target path {self.location} does not exist") + current_val = subobj[part] + elif isinstance(subobj, MutableSequence): + try: + part_idx = int(part) # type: ignore [arg-type] + current_val = subobj[part_idx] + except (IndexError, ValueError): + raise jsonpatch.JsonPatchConflict(f"Target path {self.location} does not exist") + else: + raise jsonpatch.JsonPatchConflict(f"Cannot apply str_ins to {type(subobj)}") + + if not isinstance(current_val, str): + raise jsonpatch.JsonPatchConflict(f"Target value at {self.location} is not a string") + + pos = self.operation.get("pos", len(current_val)) + if not isinstance(pos, (int, float)): + raise jsonpatch.InvalidJsonPatch("The operation 'pos' member must be a number") + pos = int(pos) + + if pos < 0 or pos > len(current_val): + raise jsonpatch.JsonPatchConflict( + f"Position {pos} is out of bounds for string of length {len(current_val)}" + ) + + # Insert logic: string slicing + new_val = current_val[:pos] + value + current_val[pos:] + + if isinstance(subobj, MutableMapping): + subobj[part] = new_val + elif isinstance(subobj, MutableSequence): + subobj[int(part)] = new_val # type: ignore [arg-type] + + return obj + + def to_string(self) -> str: + # We need to include 'pos' in the string representation + op_dict = {"op": "str_ins", "path": self.location, "value": self.operation["value"]} + if "pos" in self.operation: + op_dict["pos"] = self.operation["pos"] + return json.dumps(op_dict) + + +class ExtendedDiffBuilder(DiffBuilder): + """ + Extended DiffBuilder that detects string insertions and generates `str_ins` operations. + It uses difflib to find the most efficient string patch (currently focusing on single continuous insertion). + """ + + def _item_replaced(self, path: str, key: Any, item: Any) -> None: + """ + Called when a value is replaced. We check if it's a string modification + that can be represented as a str_ins operation. + """ + # Attempt to retrieve old value using the pointer + # path is the parent path, key is the item key/index + # We construct the full pointer to the item + full_path_str = path + "/" + str(key).replace("~", "~0").replace("/", "~1") + ptr = self.pointer_cls(full_path_str) + + try: + old_value = ptr.resolve(self.src_doc) + except Exception: + # Fallback to standard replace if we can't find old value (shouldn't happen) + super()._item_replaced(path, key, item) + return + + if isinstance(old_value, str) and isinstance(item, str): + # Optimization: Check for simple append first (O(1)ish vs O(N)) + if item.startswith(old_value) and len(item) > len(old_value): + diff = item[len(old_value) :] + self.insert( + StrInsOperation( + {"op": "str_ins", "path": full_path_str, "pos": len(old_value), "value": diff}, + pointer_cls=self.pointer_cls, + ) + ) + return + + # Analyze for arbitrary insertion using difflib + # We look for exactly ONE 'insert' block and everything else 'equal'. + # If there are deletes or replaces or multiple inserts, we fallback to full value replace + # because str_ins only does insertion, not deletion/replacement. + + matcher = difflib.SequenceMatcher(None, old_value, item) + opcodes = matcher.get_opcodes() + + insert_ops = [op for op in opcodes if op[0] == "insert"] + other_ops = [op for op in opcodes if op[0] != "insert"] + + # Condition: + # 1. Must have at least one insert (otherwise equal or delete) + # 2. All other ops must be 'equal' + # 3. Ideally only ONE insert op to keep patch simple (though we could support multiple) + # For streaming, we typically expect one chunk inserted. + + if len(insert_ops) == 1 and all(op[0] == "equal" for op in other_ops): + tag, i1, i2, j1, j2 = insert_ops[0] + inserted_text = item[j1:j2] + + # 'pos' is i1 (index in old string where insertion starts) + + self.insert( + StrInsOperation( + {"op": "str_ins", "path": full_path_str, "pos": i1, "value": inserted_text}, + pointer_cls=self.pointer_cls, + ) + ) + return + + super()._item_replaced(path, key, item) + + +class ExtendedJsonPatch(JsonPatch): + operations = dict(JsonPatch.operations) + operations["str_ins"] = StrInsOperation # type: ignore [assignment] + + +def make_patch(src: Any, dst: Any) -> ExtendedJsonPatch: + """ + Generates a patch using the ExtendedDiffBuilder. + """ + builder = ExtendedDiffBuilder(src, dst, jsonpatch.json.dumps, jsonpatch.JsonPointer) + builder._compare_values("", None, src, dst) + ops = list(builder.execute()) + # Note: jsonpatch.JsonPatch(ops) validates ops but might re-parse if passed as list of dicts? + # Actually JsonPatch init takes list of dicts or list of PatchOperations? + # It takes list of dicts usually. `builder.execute()` yields dicts (operation dicts). + # Wait, `DiffBuilder.execute` yields `PatchOperation.operation` (which is a dict). + # So `ops` is a list of dicts. + return ExtendedJsonPatch(ops) diff --git a/apps/adk-py/src/kagenti_adk/server/utils.py b/apps/adk-py/src/kagenti_adk/server/utils.py index 244f9f4c..08abc617 100644 --- a/apps/adk-py/src/kagenti_adk/server/utils.py +++ b/apps/adk-py/src/kagenti_adk/server/utils.py @@ -2,7 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations - +from kagenti_adk.types import JsonValue +from kagenti_adk.a2a.types import Metadata +from a2a.types import Message import asyncio from asyncio import CancelledError from contextlib import suppress @@ -26,3 +28,38 @@ async def close_queue(queue_manager: QueueManager, queue_name: str, immediate: b if queue := await queue_manager.get(queue_name): await queue.close(immediate=immediate) await queue_manager.close(queue_name) + + +def _merge_recursive(obj: JsonValue, other: JsonValue) -> JsonValue: + if isinstance(obj, dict) and isinstance(other, dict): + merged = {**obj} + for k, v in other.items(): + if k in merged: + merged[k] = _merge_recursive(merged[k], v) + else: + merged[k] = v + return merged + elif isinstance(obj, list) and isinstance(other, list): + return obj + other + else: + return other + + +def merge_metadata(*metadata_items: Metadata) -> Metadata: + result = Metadata() + for m in metadata_items: + for k, v in m.items(): + result[k] = _merge_recursive(result.get(k, {}), v) + return result + + +def merge_messages(*messages: Message) -> Message | None: + if not messages: + return None + merged = Message() + merged.CopyFrom(messages[0]) + for msg in messages[1:]: + merged.parts.extend(msg.parts) + for k, v in msg.metadata.items(): + merged.metadata[k] = v + return merged diff --git a/apps/adk-py/src/kagenti_adk/types.py b/apps/adk-py/src/kagenti_adk/types.py index 16777f21..49d54616 100644 --- a/apps/adk-py/src/kagenti_adk/types.py +++ b/apps/adk-py/src/kagenti_adk/types.py @@ -4,13 +4,15 @@ from __future__ import annotations import abc -from typing import TYPE_CHECKING, TypeAlias, TypedDict +from typing import TYPE_CHECKING, Required, TypeAlias, TypedDict from a2a.types import SecurityRequirement, SecurityScheme from starlette.authentication import AuthenticationBackend __all__ = [ "JsonDict", + "JsonPatch", + "JsonPatchOp", "JsonValue", "SdkAuthenticationBackend", ] @@ -19,13 +21,24 @@ JsonValue: TypeAlias = list["JsonValue"] | dict[str, "JsonValue"] | str | bool | int | float | None JsonDict: TypeAlias = dict[str, JsonValue] else: - from typing import Union + from typing import Union # noqa: F401 from typing_extensions import TypeAliasType - JsonValue = TypeAliasType("JsonValue", "Union[dict[str, JsonValue], list[JsonValue], str, int, float, bool, None]") # noqa: UP007 + JsonValue = TypeAliasType("JsonValue", "Union[dict[str, JsonValue], list[JsonValue], str, int, float, bool, None]") JsonDict = TypeAliasType("JsonDict", "dict[str, JsonValue]") +class JsonPatchOp(TypedDict, total=False): + """A single JSON Patch operation (RFC 6902), extended with 'str_ins' from json-crdt-patch.""" + + op: Required[str] + path: Required[str] + value: JsonValue + pos: int # str_ins extension: insertion position + + +JsonPatch: TypeAlias = list[JsonPatchOp] + class A2ASecurity(TypedDict): security_requirements: list[SecurityRequirement] diff --git a/apps/adk-py/tests/conftest.py b/apps/adk-py/tests/conftest.py deleted file mode 100644 index c2a8dadc..00000000 --- a/apps/adk-py/tests/conftest.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2026 © IBM Corp. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import async_lru - -# Disable async_lru event loop check in tests -async_lru._LRUCacheWrapper._check_loop = lambda self, loop: None diff --git a/apps/adk-py/tests/e2e/conftest.py b/apps/adk-py/tests/e2e/conftest.py index fe29e31f..9b0d998d 100644 --- a/apps/adk-py/tests/e2e/conftest.py +++ b/apps/adk-py/tests/e2e/conftest.py @@ -12,7 +12,8 @@ import httpx import pytest -from a2a.client import Client, ClientConfig, ClientFactory +from a2a.client import Client, ClientCallContext, ClientConfig, ClientFactory +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.types import ( AgentCard, Artifact, @@ -26,7 +27,7 @@ from tenacity import AsyncRetrying, stop_after_attempt, wait_fixed from kagenti_adk.a2a.extensions.ui.agent_detail import AgentDetail -from kagenti_adk.a2a.types import AgentArtifact, ArtifactChunk, InputRequired, RunYield, RunYieldResume +from kagenti_adk.a2a.types import AgentArtifact, AgentMessage, ArtifactChunk, InputRequired, RunYield, RunYieldResume from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext from kagenti_adk.server.store.context_store import ContextStore @@ -40,9 +41,19 @@ def get_free_port() -> int: return int(sock.getsockname()[1]) +def make_extension_context(extensions: list[str] | None = None) -> ClientCallContext | None: + """Create a ClientCallContext with extension URIs as service parameters.""" + if not extensions: + return None + return ClientCallContext(service_parameters={HTTP_EXTENSION_HEADER: ",".join(extensions)}) + + @asynccontextmanager async def run_server( - server: Server, port: int, context_store: ContextStore | None = None, task_timeout: timedelta | None = None + server: Server, + port: int, + context_store: ContextStore | None = None, + task_timeout: timedelta | None = None, ) -> AsyncGenerator[tuple[Server, Client]]: async with asyncio.TaskGroup() as tg: tg.create_task( @@ -66,7 +77,9 @@ async def run_server( card_resp = await httpx_client.get(f"{base_url}{AGENT_CARD_WELL_KNOWN_PATH}") card_resp.raise_for_status() card = ParseDict(card_resp.json(), AgentCard(), ignore_unknown_fields=True) - client = ClientFactory(ClientConfig(httpx_client=httpx_client)).create(card=card) + client = ClientFactory(ClientConfig(httpx_client=httpx_client)).create( + card=card, + ) yield server, client finally: server.should_exit = True @@ -78,7 +91,9 @@ def create_server_with_agent(): @asynccontextmanager async def _create_server( - agent_fn, context_store: ContextStore | None = None, task_timeout: timedelta | None = None + agent_fn, + context_store: ContextStore | None = None, + task_timeout: timedelta | None = None, ) -> AsyncIterator[tuple[Server, Client]]: server = Server() server.agent(detail=AgentDetail(interaction_mode="multi-turn"))(agent_fn) @@ -95,7 +110,7 @@ async def _create_server( @pytest.fixture async def echo(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: - async def echo(message: Message, context: RunContext) -> AsyncGenerator[str, Message]: + async def echo(message: Message, context: RunContext) -> AsyncGenerator[AgentMessage, Message]: for part in message.parts: yield part.text @@ -105,7 +120,7 @@ async def echo(message: Message, context: RunContext) -> AsyncGenerator[str, Mes @pytest.fixture async def slow_echo(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: - async def slow_echo(message: Message, context: RunContext) -> AsyncGenerator[str, Message]: + async def slow_echo(message: Message, context: RunContext) -> AsyncGenerator[AgentMessage, Message]: # Slower version with delay for part in message.parts: await asyncio.sleep(1) @@ -117,9 +132,9 @@ async def slow_echo(message: Message, context: RunContext) -> AsyncGenerator[str @pytest.fixture async def awaiter(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: - async def awaiter(message: Message, context: RunContext) -> AsyncGenerator[TaskStatus | str, Message]: + async def awaiter(message: Message, context: RunContext) -> AsyncGenerator[TaskStatus | AgentMessage, Message]: # Agent that requires input - yield "Processing initial message..." + yield AgentMessage(text="Processing initial message...") resume_message = yield InputRequired(text="need input") yield f"Received resume: {resume_message.parts[0].text if resume_message.parts else 'empty'}" @@ -130,9 +145,9 @@ async def awaiter(message: Message, context: RunContext) -> AsyncGenerator[TaskS @pytest.fixture async def awaiter_with_1s_timeout(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: - async def awaiter(message: Message, context: RunContext) -> AsyncGenerator[TaskStatus | str, Message]: + async def awaiter(message: Message, context: RunContext) -> AsyncGenerator[TaskStatus | AgentMessage, Message]: # Agent that requires input - yield "Processing initial message..." + yield AgentMessage(text="Processing initial message...") resume_message = yield InputRequired(text="need input") yield f"Received resume: {resume_message.parts[0].text if resume_message.parts else 'empty'}" @@ -153,7 +168,7 @@ async def failer(message: Message, context: RunContext) -> AsyncGenerator[RunYie @pytest.fixture async def raiser(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: - async def raiser(message: Message, context: RunContext) -> AsyncGenerator[str, Message]: + async def raiser(message: Message, context: RunContext) -> AsyncGenerator[AgentMessage, Message]: # Another failing agent raise RuntimeError("Wrong question buddy!") @@ -163,9 +178,11 @@ async def raiser(message: Message, context: RunContext) -> AsyncGenerator[str, M @pytest.fixture async def artifact_producer(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: - async def artifact_producer(message: Message, context: RunContext) -> AsyncGenerator[str | Artifact, Message]: + async def artifact_producer( + message: Message, context: RunContext + ) -> AsyncGenerator[AgentMessage | Artifact, Message]: # Agent producing artifacts - yield "Processing with artifacts" + yield AgentMessage(text="Processing with artifacts") # Create artifacts with proper parts structure yield AgentArtifact( @@ -195,23 +212,24 @@ async def chunked_artifact_producer( message: Message, context: RunContext ) -> AsyncGenerator[str | Artifact, Message]: # Agent producing chunked artifacts - yield "Processing chunked artifacts" + yield AgentMessage(text="Processing chunked artifacts") - # Create a large text artifact in chunks + # Create a large text artifact in chunks using ArtifactChunk with shared artifact_id + shared_id = "chunked-artifact-1" yield ArtifactChunk( - artifact_id="1", + artifact_id=shared_id, name="large-file.txt", parts=[Part(text="This is the first chunk of data.\n")], ) yield ArtifactChunk( - artifact_id="1", + artifact_id=shared_id, name="large-file.txt", parts=[Part(text="This is the second chunk of data.\n")], ) yield ArtifactChunk( - artifact_id="1", + artifact_id=shared_id, name="large-file.txt", parts=[Part(text="This is the final chunk of data.\n")], last_chunk=True, diff --git a/apps/adk-py/tests/e2e/test_streaming.py b/apps/adk-py/tests/e2e/test_streaming.py new file mode 100644 index 00000000..118efaeb --- /dev/null +++ b/apps/adk-py/tests/e2e/test_streaming.py @@ -0,0 +1,476 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import AsyncGenerator, AsyncIterator + +import pytest +from a2a.client import Client +from a2a.client.helpers import create_text_message_object +from a2a.types import ( + Artifact, + Message, + Part, + Role, + SendMessageRequest, + StreamResponse, + TaskArtifactUpdateEvent, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, +) +from google.protobuf.json_format import MessageToDict + +from kagenti_adk.a2a.extensions.streaming import ( + ArtifactDelta, + MetadataDelta, + PartDelta, + StateChange, + StreamingExtensionClient, + StreamingExtensionSpec, + TextDelta, +) +from kagenti_adk.a2a.types import AgentMessage, ArtifactChunk, InputRequired, Metadata, RunYield +from kagenti_adk.server import Server +from kagenti_adk.server.context import RunContext +from kagenti_adk.server.jsonpatch_ext import ExtendedJsonPatch +from conftest import make_extension_context + +pytestmark = pytest.mark.e2e + +STREAMING_URI = StreamingExtensionSpec.URI +STREAMING_CONTEXT = make_extension_context([STREAMING_URI]) + + +def extract_streaming_patches(events: list) -> list[dict]: + """Extract streaming patches from collected client events (flattened).""" + patches = [] + for event in events: + match event: + case (StreamResponse(status_update=TaskStatusUpdateEvent(metadata=metadata)), _) if metadata: + meta_dict = MessageToDict(metadata) + if STREAMING_URI in meta_dict: + patch_list = meta_dict[STREAMING_URI].get("message_update") + if isinstance(patch_list, list): + patches.extend(patch_list) + return patches + + +def apply_patches(patches: list[dict]) -> dict: + """Apply a sequence of streaming patches to build a message object.""" + return ExtendedJsonPatch(patches).apply({}) + + +def extract_status_events(events: list) -> list[TaskStatusUpdateEvent]: + """Extract all status update events.""" + status_events = [] + for event in events: + match event: + case (StreamResponse(status_update=TaskStatusUpdateEvent() as update), _): + if MessageToDict(update.status): + status_events.append(update) + return status_events + + +# --- Fixtures --- + + +@pytest.fixture +async def streaming_string_agent(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: + async def string_yielder(message: Message, context: RunContext) -> AsyncIterator[RunYield]: + yield "Hello" + yield " beautiful" + yield " world" + + async with create_server_with_agent( + string_yielder + ) as (server, client): + yield server, client + + +@pytest.fixture +async def streaming_part_agent(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: + async def part_yielder(message: Message, context: RunContext) -> AsyncIterator[RunYield]: + yield Part(text="explicit part") + + async with create_server_with_agent( + part_yielder + ) as (server, client): + yield server, client + + +@pytest.fixture +async def streaming_mixed_agent(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: + async def mixed_yielder(message: Message, context: RunContext) -> AsyncIterator[RunYield]: + yield "text1" + yield "text2" + yield AgentMessage(text="final") + + async with create_server_with_agent( + mixed_yielder + ) as (server, client): + yield server, client + + +@pytest.fixture +async def no_streaming_string_agent(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: + """Same agent as streaming_string_agent but WITHOUT streaming extension.""" + + async def string_yielder(message: Message, context: RunContext) -> AsyncIterator[RunYield]: + yield "Hello" + yield " beautiful" + yield " world" + + async with create_server_with_agent(string_yielder) as (server, client): + yield server, client + + +# --- Tests --- + + +async def test_string_yields_produce_streaming_patches(streaming_string_agent): + """Verify applying streaming patches builds the same message as the wire.""" + _, client = streaming_string_agent + events = [] + async for event in client.send_message(SendMessageRequest(message=create_text_message_object(content="hi")), context=STREAMING_CONTEXT): + events.append(event) + + patches = extract_streaming_patches(events) + assert len(patches) >= 2 + + # Apply all patches to build the message + built = apply_patches(patches) + assert built["parts"][0]["text"] == "Hello beautiful world" + + # Compare to the COMPLETED wire message + status_events = extract_status_events(events) + completed = [e for e in status_events if e.status.state == TaskState.TASK_STATE_COMPLETED] + wire_parts = [MessageToDict(p) for p in completed[0].status.message.parts] + assert wire_parts == built["parts"] + + +async def test_part_yield_produces_streaming_patch(streaming_part_agent): + """Verify applying streaming patches for Part yields builds the correct message.""" + _, client = streaming_part_agent + events = [] + async for event in client.send_message(SendMessageRequest(message=create_text_message_object(content="hi")), context=STREAMING_CONTEXT): + events.append(event) + + patches = extract_streaming_patches(events) + assert len(patches) >= 1 + + built = apply_patches(patches) + assert built["parts"][0]["text"] == "explicit part" + + status_events = extract_status_events(events) + completed = [e for e in status_events if e.status.state == TaskState.TASK_STATE_COMPLETED] + wire_parts = [MessageToDict(p) for p in completed[0].status.message.parts] + assert wire_parts == built["parts"] + + +async def test_completion_flushes_accumulated_message(streaming_string_agent): + """Verify that the completed event contains all accumulated text.""" + _, client = streaming_string_agent + events = [] + async for event in client.send_message(SendMessageRequest(message=create_text_message_object(content="hi")), context=STREAMING_CONTEXT): + events.append(event) + + status_events = extract_status_events(events) + completed = [e for e in status_events if e.status.state == TaskState.TASK_STATE_COMPLETED] + assert len(completed) == 1 + + final_message = completed[0].status.message + assert final_message.parts[0].text == "Hello beautiful world" + + +async def test_mixed_yields_message_flush(streaming_mixed_agent): + """Verify accumulated text is flushed before explicit AgentMessage.""" + _, client = streaming_mixed_agent + events = [] + async for event in client.send_message(SendMessageRequest(message=create_text_message_object(content="hi")), context=STREAMING_CONTEXT): + events.append(event) + + # Extract all messages from WORKING status events + working_messages = [] + for event in events: + match event: + case (StreamResponse(status_update=TaskStatusUpdateEvent(status=TaskStatus( + state=TaskState.TASK_STATE_WORKING, message=Message(message_id=mid) + ))), _) if mid: + working_messages.append(event[0].status_update.status.message) + + # Should have at least the explicit AgentMessage as a WORKING event + # The accumulated "text1text2" is flushed as a draft merged into the AgentMessage + assert len(working_messages) >= 1 + + # Verify the completed event contains both messages in history + status_events = extract_status_events(events) + completed = [e for e in status_events if e.status.state == TaskState.TASK_STATE_COMPLETED] + assert len(completed) == 1 + + +async def test_no_streaming_patches_without_extension(no_streaming_string_agent): + """Verify no streaming metadata when extension is not activated by client.""" + _, client = no_streaming_string_agent + events = [] + async for event in client.send_message(SendMessageRequest(message=create_text_message_object(content="hi"))): + events.append(event) + + patches = extract_streaming_patches(events) + assert len(patches) == 0 + + # But the final message should still have the accumulated text + status_events = extract_status_events(events) + completed = [e for e in status_events if e.status.state == TaskState.TASK_STATE_COMPLETED] + assert len(completed) == 1 + assert completed[0].status.message.parts[0].text == "Hello beautiful world" + + +async def test_complex_accumulator_state_machine(create_server_with_agent): + """Comprehensive test exercising all accumulator state transitions in a single scenario. + + Agent yield sequence and expected state machine transitions: + + yield "Hello" Base → TextPart (add /parts/-) + yield " world" TextPart → TextPart (str_ins) + yield Part("[separator]") TextPart → Message (flush text part, add explicit part) + yield {"score": 42} Message → Message (dict→Part, add /parts/-) + yield Metadata(ext: [a]) Message → Message (add /metadata) + yield Metadata(ext: [b]) Message → Message (replace /metadata, arrays concatenated) + yield AgentMessage("ckpt") Message → Base (flush draft, dispatch WORKING message) + yield ArtifactChunk(...) handled outside accumulator (artifact event) + yield "post-artifact" Base → TextPart (new accumulation after reset) + yield InputRequired(...) TextPart → Base (flush draft, INPUT_REQUIRED) + --- resume --- + yield "you said: ..." Base → TextPart (new accumulation) + implicit flush → COMPLETED + """ + + async def complex_agent(message: Message, context: RunContext) -> AsyncIterator[RunYield]: + # Phase 1: Text streaming (Base → TextPart → TextPart) + yield "Hello" + yield " world" + + # Phase 2: Part after text (TextPart → Message; flushes text, adds explicit part) + yield Part(text="[separator]") + + # Phase 3: Dict yield (Message → Message; dict converted to data Part) + yield {"score": 42} + + # Phase 4: Metadata accumulation with array concatenation + yield Metadata({"ext://test": [{"ref": "a"}]}) + yield Metadata({"ext://test": [{"ref": "b"}]}) + + # Phase 5: Explicit Message flushes everything accumulated so far + # Draft: parts=[Text("Hello world"), Part("[separator]"), DataPart({score:42})], + # metadata={ext://test: [{ref:a},{ref:b}]} + # Merged with AgentMessage: draft parts + [Part("checkpoint")] + yield AgentMessage(text="checkpoint") + + # Phase 6: Artifact (bypasses accumulator entirely) + yield ArtifactChunk( + artifact_id="art-1", name="data.txt", + parts=[Part(text="artifact body")], last_chunk=True, + ) + + # Phase 7: New accumulation cycle after full reset + yield "post-artifact" + + # Phase 8: InputRequired flushes accumulated text, pauses for user input + # Draft: parts=[Text("post-artifact")] + # Merged with InputRequired message: [Text("post-artifact"), Text("what next?")] + resume_msg = yield InputRequired(text="what next?") + + # Phase 9: After resume — new accumulation, then implicit flush on return + yield f"you said: {resume_msg.parts[0].text}" + + async with create_server_with_agent(complex_agent) as (_, client): + # --- First send: initial message --- + events_1 = [] + async for event in client.send_message(SendMessageRequest(message=create_text_message_object(content="go")), context=STREAMING_CONTEXT): + events_1.append(event) + + all_patches = extract_streaming_patches(events_1) + status_events = extract_status_events(events_1) + + # Split patches into accumulation cycles by root replace boundaries + cycles: list[list[dict]] = [] + for patch in all_patches: + if patch["op"] == "replace" and patch["path"] == "": + cycles.append([]) + cycles[-1].append(patch) + assert len(cycles) == 2 # cycle 1: before AgentMessage, cycle 2: after artifact + + # -- Cycle 1: Apply patches to build the draft -- + # Yields: "Hello", " world", Part("[separator]"), {"score": 42}, Metadata x2 + draft_1 = apply_patches(cycles[0]) + assert draft_1["parts"][0]["text"] == "Hello world" + assert draft_1["parts"][1]["text"] == "[separator]" + assert draft_1["parts"][2]["data"] == {"score": 42.0} + assert draft_1["metadata"]["ext://test"] == [{"ref": "a"}, {"ref": "b"}] + + # The WORKING wire message = merge(draft_1, AgentMessage("checkpoint")) + # Draft parts are a prefix of the wire message parts + working = [ + e for e in status_events + if e.status.state == TaskState.TASK_STATE_WORKING and e.status.message.message_id + ] + assert len(working) == 1 + wire_msg = working[0].status.message + wire_parts = [MessageToDict(p) for p in wire_msg.parts] + wire_meta = MessageToDict(wire_msg.metadata) + + # Draft's 3 parts are the prefix; AgentMessage adds "checkpoint" as the 4th + assert wire_parts[:3] == draft_1["parts"] + assert wire_parts[3]["text"] == "checkpoint" + assert wire_meta == draft_1["metadata"] + + # -- Artifact event (bypasses accumulator) -- + artifact_events = [ + event for event in events_1 + if isinstance(event[0], StreamResponse) and event[0].artifact_update.artifact.artifact_id + ] + assert len(artifact_events) == 1 + assert artifact_events[0][0].artifact_update.artifact.name == "data.txt" + + # -- Cycle 2: Apply patches to build the draft -- + # Yields: "post-artifact" (then InputRequired flushes) + draft_2 = apply_patches(cycles[1]) + assert draft_2["parts"][0]["text"] == "post-artifact" + + # The INPUT_REQUIRED wire message = merge(draft_2, InputRequired("what next?")) + input_required = [ + e for e in status_events + if e.status.state == TaskState.TASK_STATE_INPUT_REQUIRED + ] + assert len(input_required) == 1 + ir_parts = [MessageToDict(p) for p in input_required[0].status.message.parts] + assert ir_parts[:1] == draft_2["parts"] + assert ir_parts[1]["text"] == "what next?" + + # --- Second send: resume --- + task_id = events_1[-1][1].id + resume = create_text_message_object(content="hello again") + resume.task_id = task_id + + events_2 = [] + async for event in client.send_message(SendMessageRequest(message=resume), context=STREAMING_CONTEXT): + events_2.append(event) + + status_events_2 = extract_status_events(events_2) + patches_2 = extract_streaming_patches(events_2) + + # Cycle 3: Apply patches — should match the COMPLETED wire message exactly + draft_3 = apply_patches(patches_2) + completed = [e for e in status_events_2 if e.status.state == TaskState.TASK_STATE_COMPLETED] + assert len(completed) == 1 + wire_completed_parts = [MessageToDict(p) for p in completed[0].status.message.parts] + assert wire_completed_parts == draft_3["parts"] + + +# --- StreamingExtensionClient tests --- + + +def _make_streaming_client() -> StreamingExtensionClient: + return StreamingExtensionClient(StreamingExtensionSpec()) + + +async def test_streaming_client_text_deltas(streaming_string_agent): + """Verify StreamingExtensionClient emits TextDelta for streamed text chunks.""" + _, client = streaming_string_agent + streaming = _make_streaming_client() + + text_deltas = [] + state_changes = [] + async for delta, task in streaming.stream(client.send_message(SendMessageRequest(message=create_text_message_object(content="hi")), context=STREAMING_CONTEXT)): + match delta: + case TextDelta() as td: + text_deltas.append(td.delta) + case PartDelta(): + pass + case StateChange() as sc: + state_changes.append(sc) + + # Should have text deltas for " beautiful" and " world" (first chunk is PartDelta from root replace) + assert len(text_deltas) >= 2 + + # Verify final state change is COMPLETED + completed = [sc for sc in state_changes if sc.state == TaskState.TASK_STATE_COMPLETED] + assert len(completed) == 1 + # The completed message should have been reconciled (already streamed) + assert completed[0].message is not None + + +async def test_streaming_client_part_delta(streaming_part_agent): + """Verify StreamingExtensionClient emits PartDelta for explicit Part yields.""" + _, client = streaming_part_agent + streaming = _make_streaming_client() + + part_deltas = [] + async for delta, task in streaming.stream(client.send_message(SendMessageRequest(message=create_text_message_object(content="hi")), context=STREAMING_CONTEXT)): + match delta: + case PartDelta() as pd: + part_deltas.append(pd) + case _: + pass + + assert len(part_deltas) >= 1 + assert part_deltas[0].part["text"] == "explicit part" + + +async def test_streaming_client_without_extension(no_streaming_string_agent): + """Verify StreamingExtensionClient works without streaming extension (decompose full messages).""" + _, client = no_streaming_string_agent + streaming = _make_streaming_client() + + part_deltas = [] + state_changes = [] + async for delta, task in streaming.stream(client.send_message(SendMessageRequest(message=create_text_message_object(content="hi")))): + match delta: + case PartDelta() as pd: + part_deltas.append(pd) + case StateChange() as sc: + state_changes.append(sc) + case _: + pass + + # Without streaming, the completed message should be decomposed into PartDelta + completed = [sc for sc in state_changes if sc.state == TaskState.TASK_STATE_COMPLETED] + assert len(completed) == 1 + # The message parts should appear as PartDelta events + assert any(pd.part.get("text") == "Hello beautiful world" for pd in part_deltas) + + +async def test_streaming_client_reconciles_streamed_messages(streaming_mixed_agent): + """Verify that messages already streamed via patches are properly reconciled.""" + _, client = streaming_mixed_agent + streaming = _make_streaming_client() + + all_deltas = [] + async for delta, task in streaming.stream(client.send_message(SendMessageRequest(message=create_text_message_object(content="hi")), context=STREAMING_CONTEXT)): + all_deltas.append(delta) + + # Should have deltas from streaming (text1, text2) and then the explicit AgentMessage + text_deltas = [d for d in all_deltas if isinstance(d, TextDelta)] + part_deltas = [d for d in all_deltas if isinstance(d, PartDelta)] + state_changes = [d for d in all_deltas if isinstance(d, StateChange)] + + # The streaming patches produce text/part deltas for "text1" and "text2" + assert len(text_deltas) + len(part_deltas) >= 1 + + # Should end with COMPLETED state + assert any(sc.state == TaskState.TASK_STATE_COMPLETED for sc in state_changes) + + +async def test_streaming_client_message_id_tracking(streaming_string_agent): + """Verify message_id is tracked from streaming patches.""" + _, client = streaming_string_agent + streaming = _make_streaming_client() + + async for delta, task in streaming.stream(client.send_message(SendMessageRequest(message=create_text_message_object(content="hi")), context=STREAMING_CONTEXT)): + pass + + # After stream completes, the draft should have been used + # message_id should have been set during streaming + # (it gets cleared after reconciliation, so we check the completed state) diff --git a/apps/adk-py/tests/e2e/test_yields.py b/apps/adk-py/tests/e2e/test_yields.py index 7b1f347c..143797ae 100644 --- a/apps/adk-py/tests/e2e/test_yields.py +++ b/apps/adk-py/tests/e2e/test_yields.py @@ -236,11 +236,13 @@ async def test_sync_function_with_context_agent(sync_function_with_context_agent assert final_task is not None assert final_task.status.state == TaskState.TASK_STATE_COMPLETED - # Should have intermediate yield and final result + # Consecutive string yields are accumulated into a single message # pyrefly: ignore [missing-attribute, not-iterable] - messages = [msg.parts[0].text for msg in final_task.history if msg.role == Role.ROLE_AGENT] - assert "first sync yield" in messages - assert "sync_function_with_context_agent: hello" in messages + agent_messages = [msg for msg in final_task.history if msg.role == Role.ROLE_AGENT] + assert len(agent_messages) == 1 + text = agent_messages[0].parts[0].text + assert "first sync yield" in text + assert "sync_function_with_context_agent: hello" in text async def test_sync_generator_agent(sync_generator_agent): @@ -252,10 +254,13 @@ async def test_sync_generator_agent(sync_generator_agent): assert final_task is not None assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + # Consecutive string yields are accumulated into a single message # pyrefly: ignore [missing-attribute, not-iterable] - messages = [msg.parts[0].text for msg in final_task.history if msg.role == Role.ROLE_AGENT] - assert "sync_generator yield 1" in messages - assert "sync_generator yield 2" in messages + agent_messages = [msg for msg in final_task.history if msg.role == Role.ROLE_AGENT] + assert len(agent_messages) == 1 + text = agent_messages[0].parts[0].text + assert "sync_generator yield 1" in text + assert "sync_generator yield 2" in text async def test_sync_generator_with_context_agent(sync_generator_with_context_agent): @@ -267,12 +272,15 @@ async def test_sync_generator_with_context_agent(sync_generator_with_context_age assert final_task is not None assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + # Consecutive string yields are accumulated into a single message # pyrefly: ignore [missing-attribute, not-iterable] - messages = [msg.parts[0].text for msg in final_task.history if msg.role == Role.ROLE_AGENT] - assert "sync_generator_with_context yield 1" in messages - assert "sync_generator_with_context context yield" in messages - assert "sync_generator_with_context yield 2" in messages - assert "sync_generator_with_context_agent: hello" in messages + agent_messages = [msg for msg in final_task.history if msg.role == Role.ROLE_AGENT] + assert len(agent_messages) == 1 + text = agent_messages[0].parts[0].text + assert "sync_generator_with_context yield 1" in text + assert "sync_generator_with_context context yield" in text + assert "sync_generator_with_context yield 2" in text + assert "sync_generator_with_context_agent: hello" in text async def test_async_function_agent(async_function_agent): @@ -297,10 +305,13 @@ async def test_async_function_with_context_agent(async_function_with_context_age assert final_task is not None assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + # Consecutive string yields are accumulated into a single message # pyrefly: ignore [missing-attribute, not-iterable] - messages = [msg.parts[0].text for msg in final_task.history if msg.role == Role.ROLE_AGENT] - assert "first async yield" in messages - assert "async_function_with_context_agent: hello" in messages + agent_messages = [msg for msg in final_task.history if msg.role == Role.ROLE_AGENT] + assert len(agent_messages) == 1 + text = agent_messages[0].parts[0].text + assert "first async yield" in text + assert "async_function_with_context_agent: hello" in text async def test_async_generator_agent(async_generator_agent): @@ -312,11 +323,14 @@ async def test_async_generator_agent(async_generator_agent): assert final_task is not None assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + # Consecutive string yields are accumulated into a single message # pyrefly: ignore [missing-attribute, not-iterable] - messages = [msg.parts[0].text for msg in final_task.history if msg.role == Role.ROLE_AGENT] - assert "async_generator yield 1" in messages - assert "async_generator yield 2" in messages - assert "async_generator_agent: hello" in messages + agent_messages = [msg for msg in final_task.history if msg.role == Role.ROLE_AGENT] + assert len(agent_messages) == 1 + text = agent_messages[0].parts[0].text + assert "async_generator yield 1" in text + assert "async_generator yield 2" in text + assert "async_generator_agent: hello" in text async def test_async_generator_with_context_agent(async_generator_with_context_agent): @@ -328,12 +342,15 @@ async def test_async_generator_with_context_agent(async_generator_with_context_a assert final_task is not None assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + # Consecutive string yields are accumulated into a single message # pyrefly: ignore [missing-attribute, not-iterable] - messages = [msg.parts[0].text for msg in final_task.history if msg.role == Role.ROLE_AGENT] - assert "async_generator_with_context yield 1" in messages - assert "async_generator_with_context context yield" in messages - assert "async_generator_with_context yield 2" in messages - assert "async_generator_with_context_agent: hello" in messages + agent_messages = [msg for msg in final_task.history if msg.role == Role.ROLE_AGENT] + assert len(agent_messages) == 1 + text = agent_messages[0].parts[0].text + assert "async_generator_with_context yield 1" in text + assert "async_generator_with_context context yield" in text + assert "async_generator_with_context yield 2" in text + assert "async_generator_with_context_agent: hello" in text async def test_sync_function_resume_agent(sync_function_resume_agent): @@ -369,9 +386,10 @@ async def test_sync_generator_resume_agent(sync_generator_resume_agent): assert initial_task is not None assert initial_task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED + # The "starting" string is flushed as a message before the InputRequired status # pyrefly: ignore [missing-attribute, not-iterable] - messages = [msg.parts[0].text for msg in initial_task.history if msg.role == Role.ROLE_AGENT] - assert "sync_generator_resume_agent: starting" in messages + agent_messages = [msg for msg in initial_task.history if msg.role == Role.ROLE_AGENT] + assert any("sync_generator_resume_agent: starting" in msg.parts[0].text for msg in agent_messages) # Resume with additional data resume_message = create_text_message_object(content="resume data") @@ -383,8 +401,8 @@ async def test_sync_generator_resume_agent(sync_generator_resume_agent): assert final_task is not None assert final_task.status.state == TaskState.TASK_STATE_COMPLETED # pyrefly: ignore [missing-attribute, not-iterable] - messages = [msg.parts[0].text for msg in final_task.history if msg.role == Role.ROLE_AGENT] - assert "sync_generator_resume_agent: received resume data" in messages + agent_messages = [msg for msg in final_task.history if msg.role == Role.ROLE_AGENT] + assert any("sync_generator_resume_agent: received resume data" in msg.parts[0].text for msg in agent_messages) async def test_async_function_resume_agent(async_function_resume_agent): @@ -421,9 +439,10 @@ async def test_async_generator_resume_agent(async_generator_resume_agent): assert initial_task is not None assert initial_task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED + # The "starting" string is flushed as a message before the InputRequired status # pyrefly: ignore [missing-attribute, not-iterable] - messages = [msg.parts[0].text for msg in initial_task.history if msg.role == Role.ROLE_AGENT] - assert "async_generator_resume_agent: starting" in messages + agent_messages = [msg for msg in initial_task.history if msg.role == Role.ROLE_AGENT] + assert any("async_generator_resume_agent: starting" in msg.parts[0].text for msg in agent_messages) # Resume with additional data resume_message = create_text_message_object(content="resume data") @@ -435,8 +454,8 @@ async def test_async_generator_resume_agent(async_generator_resume_agent): assert final_task is not None assert final_task.status.state == TaskState.TASK_STATE_COMPLETED # pyrefly: ignore [missing-attribute, not-iterable] - messages = [msg.parts[0].text for msg in final_task.history if msg.role == Role.ROLE_AGENT] - assert "async_generator_resume_agent: received resume data" in messages + agent_messages = [msg for msg in final_task.history if msg.role == Role.ROLE_AGENT] + assert any("async_generator_resume_agent: received resume data" in msg.parts[0].text for msg in agent_messages) async def test_sync_function_streaming(sync_function_agent): @@ -457,7 +476,7 @@ async def test_sync_function_streaming(sync_function_agent): async def test_sync_generator_streaming(sync_generator_agent): - """Test synchronous generator agent with streaming to see intermediate yields.""" + """Test synchronous generator agent with streaming events.""" _, client = sync_generator_agent events = [] async for event in client.send_message(SendMessageRequest(message=create_text_message_object(content="hello"))): @@ -472,13 +491,14 @@ async def test_sync_generator_streaming(sync_generator_agent): assert len(status_events) > 0 assert status_events[-1].status.state == TaskState.TASK_STATE_COMPLETED - # Should see multiple working state messages for each yield - working_events = [e for e in status_events if e.status.state == TaskState.TASK_STATE_WORKING] - assert len(working_events) >= 3 # At least 3 yields from the generator + # Without the streaming extension activated by the client, partial updates are not sent. + # The final completed message contains the accumulated result. + completed = status_events[-1] + assert completed.status.message.parts[0].text async def test_async_generator_streaming(async_generator_agent): - """Test asynchronous generator agent with streaming to see intermediate yields.""" + """Test asynchronous generator agent with streaming events.""" _, client = async_generator_agent events = [] async for event in client.send_message(SendMessageRequest(message=create_text_message_object(content="hello"))): @@ -492,13 +512,15 @@ async def test_async_generator_streaming(async_generator_agent): assert len(status_events) > 0 assert status_events[-1].status.state == TaskState.TASK_STATE_COMPLETED - # Should see multiple working state messages for each yield - working_events = [e for e in status_events if e.status.state == TaskState.TASK_STATE_WORKING] - assert len(working_events) >= 2 # At least 2 yields from the generator + # Without the streaming extension activated by the client, partial updates are not sent. + # The final completed message contains the accumulated result. + completed = status_events[-1] + assert completed.status.message.parts[0].text async def test_yield_dict_vs_metadata(create_server_with_agent): async def yielder_of_meta_data() -> AsyncIterator[RunYield]: + # dict → Part(data=...), Metadata, and AgentMessage are accumulated into one message yield {"data": "this should be datapart"} yield Metadata({"metadata": "this should be metadata"}) yield AgentMessage( @@ -513,26 +535,24 @@ async def yielder_of_meta_data() -> AsyncIterator[RunYield]: assert final_task is not None assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + # dict+Metadata are accumulated, then flushed as draft when AgentMessage arrives + # The final message merges the draft (data part + metadata) with the AgentMessage (metadata) + # pyrefly: ignore [missing-attribute, not-iterable] + agent_messages = [msg for msg in final_task.history if msg.role == Role.ROLE_AGENT] + assert len(agent_messages) == 1 + merged = agent_messages[0] # pyrefly: ignore [missing-attribute, unsupported-operation] - assert MessageToDict(final_task.history[0].parts[0].data) == {"data": "this should be datapart"} + assert MessageToDict(merged.parts[0].data) == {"data": "this should be datapart"} + # Metadata from both Metadata yield and AgentMessage are merged # pyrefly: ignore [unsupported-operation] - assert MessageToDict(final_task.history[1].metadata) == {"metadata": "this should be metadata"} - # pyrefly: ignore [unsupported-operation] - assert MessageToDict(final_task.history[2].metadata) == { - "metadata": "this class still behaves as dict", - "metadata2": "and can be used in union", - } - # pyrefly: ignore [unsupported-operation] - assert not final_task.history[0].metadata - # pyrefly: ignore [unsupported-operation] - assert not final_task.history[1].parts - # pyrefly: ignore [unsupported-operation] - assert not final_task.history[2].parts + merged_meta = MessageToDict(merged.metadata) + assert merged_meta["metadata"] == "this class still behaves as dict" + assert merged_meta["metadata2"] == "and can be used in union" async def test_yield_of_all_types(create_server_with_agent): async def yielder_of_all_types_agent(message: Message, context: RunContext) -> AsyncIterator[RunYield]: - """Synchronous function agent that returns a string directly.""" + """Agent that yields all supported types.""" text_part = Part(text="text") message = AgentMessage(parts=[text_part], role=Role.ROLE_AGENT, message_id=str(uuid.uuid4())) yield message @@ -571,5 +591,10 @@ async def yielder_of_all_types_agent(message: Message, context: RunContext) -> A _, ) if artifact_id: artifact_cnt += 1 - assert message_cnt == 9 + # With streaming accumulation: + # - Message yield → 1 message + # - text_part accumulated, then TaskStatus flushes → 1 merged message + # - Parts accumulated, then TaskStatusUpdateEvent flushes → 1 merged message + # - str/dict/Metadata accumulated, then completion flushes → 1 message + assert message_cnt == 4 assert artifact_cnt == 2 diff --git a/apps/adk-py/tests/test_merge_utils.py b/apps/adk-py/tests/test_merge_utils.py new file mode 100644 index 00000000..015943a4 --- /dev/null +++ b/apps/adk-py/tests/test_merge_utils.py @@ -0,0 +1,60 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC + +from kagenti_adk.server.utils import _merge_recursive, merge_metadata +from kagenti_adk.a2a.types import Metadata +import pytest + + +def test_scalar_merge(): + assert _merge_recursive(1, 2) == 2 + assert _merge_recursive("a", "b") == "b" + assert _merge_recursive(True, False) == False + assert _merge_recursive(None, 1) == 1 + + +def test_list_merge(): + assert _merge_recursive([1, 2], [3, 4]) == [1, 2, 3, 4] + assert _merge_recursive([], [1]) == [1] + assert _merge_recursive([1], []) == [1] + + +def test_dict_merge_simple(): + a = {"x": 1} + b = {"y": 2} + expected = {"x": 1, "y": 2} + assert _merge_recursive(a, b) == expected + + +def test_dict_merge_overwrite(): + a = {"x": 1} + b = {"x": 2} + expected = {"x": 2} + assert _merge_recursive(a, b) == expected + + +def test_dict_merge_recursive(): + a = {"x": 1, "y": {"a": 1}} + b = {"y": {"b": 2}, "z": 3} + expected = {"x": 1, "y": {"a": 1, "b": 2}, "z": 3} + assert _merge_recursive(a, b) == expected + + +def test_dict_merge_nested_list(): + a = {"x": [1]} + b = {"x": [2]} + expected = {"x": [1, 2]} + assert _merge_recursive(a, b) == expected + + +def test_merge_metadata(): + m1 = Metadata(foo="bar") + m2 = Metadata(baz="qux") + merged = merge_metadata(m1, m2) + assert merged == Metadata(foo="bar", baz="qux") + + +def test_merge_metadata_recursive(): + m1 = Metadata(nested={"a": 1}) + m2 = Metadata(nested={"b": 2}) + merged = merge_metadata(m1, m2) + assert merged == Metadata(nested={"a": 1, "b": 2}) diff --git a/apps/adk-py/tests/unit/server/test_accumulator.py b/apps/adk-py/tests/unit/server/test_accumulator.py new file mode 100644 index 00000000..0cd9c979 --- /dev/null +++ b/apps/adk-py/tests/unit/server/test_accumulator.py @@ -0,0 +1,292 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest +from a2a.types import Part, Role, TaskState, TaskStatus, TaskStatusUpdateEvent + +from kagenti_adk.a2a.types import AgentMessage, Metadata +from kagenti_adk.server.accumulator import MessageAccumulator, MessageContext, TextPartContext + +pytestmark = pytest.mark.unit + + +# --- State transition tests --- + + +class TestStateTransitions: + def test_initial_state_is_base_level(self): + acc = MessageAccumulator() + assert isinstance(acc.active_context, MessageAccumulator) + + def test_string_enters_text_part_context(self): + acc = MessageAccumulator() + result = acc.process("hello") + assert result.accumulated is True + assert isinstance(acc.active_context, TextPartContext) + + def test_consecutive_strings_stay_in_text_part(self): + acc = MessageAccumulator() + acc.process("a") + result = acc.process("b") + assert result.accumulated is True + assert isinstance(acc.active_context, TextPartContext) + + def test_part_enters_message_context(self): + acc = MessageAccumulator() + result = acc.process(Part(text="x")) + assert result.accumulated is True + assert isinstance(acc.active_context, MessageContext) + + def test_metadata_enters_message_context(self): + acc = MessageAccumulator() + result = acc.process(Metadata({"key": "val"})) + assert result.accumulated is True + assert isinstance(acc.active_context, MessageContext) + + def test_part_after_string_transitions_to_message_context(self): + acc = MessageAccumulator() + acc.process("text chunk") + result = acc.process(Part(text="explicit part")) + assert result.accumulated is True + assert isinstance(acc.active_context, MessageContext) + # The text chunk was built into a Part and added to the MessageContext + ctx = acc.active_context + assert len(ctx.parts) == 2 + assert ctx.parts[0].text == "text chunk" + assert ctx.parts[1].text == "explicit part" + + def test_message_passthrough_at_base_level(self): + acc = MessageAccumulator() + msg = AgentMessage(text="hello") + result = acc.process(msg) + assert result.accumulated is False + assert result.draft is None + assert isinstance(acc.active_context, MessageAccumulator) + + def test_task_status_passthrough_at_base_level(self): + acc = MessageAccumulator() + status = TaskStatus(state=TaskState.TASK_STATE_WORKING) + result = acc.process(status) + assert result.accumulated is False + assert result.draft is None + assert isinstance(acc.active_context, MessageAccumulator) + + def test_task_status_update_event_passthrough_at_base_level(self): + acc = MessageAccumulator() + event = TaskStatusUpdateEvent( + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + task_id="t1", + context_id="c1", + ) + result = acc.process(event) + assert result.accumulated is False + assert isinstance(acc.active_context, MessageAccumulator) + + def test_message_flushes_accumulated_text(self): + acc = MessageAccumulator() + acc.process("a") + acc.process("b") + msg = AgentMessage(text="final") + result = acc.process(msg) + assert result.accumulated is False + assert result.draft is not None + assert result.draft.parts[0].text == "ab" + assert isinstance(acc.active_context, MessageAccumulator) + + def test_task_status_flushes_accumulated_parts(self): + acc = MessageAccumulator() + acc.process(Part(text="hello")) + status = TaskStatus(state=TaskState.TASK_STATE_WORKING) + result = acc.process(status) + assert result.accumulated is False + assert result.draft is not None + assert result.draft.parts[0].text == "hello" + assert isinstance(acc.active_context, MessageAccumulator) + + def test_message_flushes_accumulated_text_and_parts(self): + acc = MessageAccumulator() + acc.process("text chunk") + acc.process(Part(text="explicit part")) + acc.process(Metadata({"key": "val"})) + msg = AgentMessage(text="final") + result = acc.process(msg) + assert result.accumulated is False + assert result.draft is not None + # Draft should contain the accumulated text part and the explicit part + assert len(result.draft.parts) == 2 + assert result.draft.parts[0].text == "text chunk" + assert result.draft.parts[1].text == "explicit part" + assert isinstance(acc.active_context, MessageAccumulator) + + def test_string_after_part_enters_text_part_context(self): + acc = MessageAccumulator() + acc.process(Part(text="first")) + assert isinstance(acc.active_context, MessageContext) + acc.process("streaming text") + assert isinstance(acc.active_context, TextPartContext) + + def test_input_required_flushes_text(self): + acc = MessageAccumulator() + acc.process("thinking...") + status = TaskStatus(state=TaskState.TASK_STATE_INPUT_REQUIRED) + result = acc.process(status) + assert result.accumulated is False + assert result.draft is not None + assert result.draft.parts[0].text == "thinking..." + + +# --- Patch verification tests --- + + +class TestPatchOutput: + def test_first_string_produces_replace_root_patch(self): + acc = MessageAccumulator() + result = acc.process("Hello") + assert result.patch is not None + assert len(result.patch) == 1 + patch = result.patch[0] + assert patch["op"] == "replace" + assert patch["path"] == "" + value = patch["value"] # type: ignore[typeddict-item] + assert value["parts"] == [{"text": "Hello"}] # type: ignore[index] + assert "message_id" in value # type: ignore[operator] + + def test_subsequent_string_produces_str_ins_patch(self): + acc = MessageAccumulator() + acc.process("Hello") + result = acc.process(" world") + assert result.patch is not None + assert len(result.patch) == 1 + patch = result.patch[0] + assert patch["op"] == "str_ins" + assert patch["path"] == "/parts/0/text" + assert patch["pos"] == 5 + assert patch["value"] == " world" + + def test_str_ins_path_uses_correct_index_after_parts(self): + acc = MessageAccumulator() + acc.process(Part(text="first")) + acc.process(Part(text="second")) + acc.process("stream") # add patch at index 2 + result = acc.process("ing") # str_ins at index 2 + assert result.patch is not None + assert result.patch[0]["path"] == "/parts/2/text" + + def test_first_part_produces_replace_root_patch(self): + acc = MessageAccumulator() + result = acc.process(Part(text="hello")) + assert result.patch is not None + assert len(result.patch) == 1 + patch = result.patch[0] + assert patch["op"] == "replace" + assert patch["path"] == "" + value = patch["value"] + assert value["parts"] == [{"text": "hello"}] # type: ignore[index] + assert "message_id" in value # type: ignore[operator] + + def test_first_metadata_produces_replace_root_patch(self): + acc = MessageAccumulator() + result = acc.process(Metadata({"key": "val"})) + assert result.patch is not None + assert len(result.patch) == 1 + patch = result.patch[0] + assert patch["op"] == "replace" + assert patch["path"] == "" + value = patch["value"] + assert value["parts"] == [] # type: ignore[index] + assert value["metadata"] == {"key": "val"} # type: ignore[index] + assert "message_id" in value # type: ignore[operator] + + def test_second_metadata_produces_incremental_patches(self): + acc = MessageAccumulator() + acc.process(Metadata({"ext://test": [{"ref": "a"}]})) + result = acc.process(Metadata({"ext://test": [{"ref": "b"}]})) + assert result.patch is not None + # Should produce incremental add patch, not a full replace + assert len(result.patch) >= 1 + # The patches should target /metadata/... paths + for op in result.patch: + assert op["path"].startswith("/metadata/") + # Apply all patches to verify correctness + from kagenti_adk.server.jsonpatch_ext import ExtendedJsonPatch + draft = {"parts": [], "metadata": {"ext://test": [{"ref": "a"}]}} + draft = ExtendedJsonPatch(result.patch).apply(draft) + assert draft["metadata"]["ext://test"] == [{"ref": "a"}, {"ref": "b"}] + + def test_message_id_propagated_in_all_accumulated_results(self): + acc = MessageAccumulator() + r1 = acc.process("Hello") + assert r1.message_id is not None + r2 = acc.process(" world") + assert r2.message_id == r1.message_id + r3 = acc.process(Part(text="part")) + assert r3.message_id == r1.message_id + r4 = acc.process(Metadata({"k": "v"})) + assert r4.message_id == r1.message_id + + def test_passthrough_has_no_patch(self): + acc = MessageAccumulator() + result = acc.process(AgentMessage(text="hello")) + assert result.patch is None + + +# --- Flush tests --- + + +class TestFlush: + def test_flush_from_text_part_returns_message(self): + acc = MessageAccumulator() + acc.process("hello") + acc.process(" world") + msg = acc.flush() + assert msg is not None + assert msg.role == Role.ROLE_AGENT + assert msg.parts[0].text == "hello world" + + def test_flush_from_message_context_returns_message(self): + acc = MessageAccumulator() + acc.process(Part(text="a")) + acc.process(Part(text="b")) + msg = acc.flush() + assert msg is not None + assert len(msg.parts) == 2 + assert msg.parts[0].text == "a" + assert msg.parts[1].text == "b" + + def test_flush_from_base_level_returns_none(self): + acc = MessageAccumulator() + assert acc.flush() is None + + def test_flush_resets_to_base_level(self): + acc = MessageAccumulator() + acc.process("hello") + assert isinstance(acc.active_context, TextPartContext) + acc.flush() + assert isinstance(acc.active_context, MessageAccumulator) + + def test_flush_after_passthrough_returns_none(self): + acc = MessageAccumulator() + acc.process("hello") + acc.process(AgentMessage(text="explicit")) # flushes internally + assert isinstance(acc.active_context, MessageAccumulator) + assert acc.flush() is None + + def test_flush_with_metadata(self): + acc = MessageAccumulator() + acc.process(Part(text="content")) + acc.process(Metadata({"key": "value"})) + msg = acc.flush() + assert msg is not None + assert msg.parts[0].text == "content" + from google.protobuf.json_format import MessageToDict + + assert MessageToDict(msg.metadata) == {"key": "value"} + + def test_double_flush_returns_none(self): + acc = MessageAccumulator() + acc.process("hello") + msg = acc.flush() + assert msg is not None + assert acc.flush() is None diff --git a/apps/adk-py/tests/unit/server/test_context.py b/apps/adk-py/tests/unit/server/test_context.py new file mode 100644 index 00000000..b31a5966 --- /dev/null +++ b/apps/adk-py/tests/unit/server/test_context.py @@ -0,0 +1,188 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest +from a2a.types import Part, Role +from google.protobuf.json_format import MessageToDict + +from kagenti_adk.a2a.types import Metadata +from kagenti_adk.server.accumulator import MessageContext, TextPartContext + +pytestmark = pytest.mark.unit + + +# --- TextPartContext --- + + +class TestTextPartContext: + def _make(self, part_index: int = 0) -> TextPartContext: + return TextPartContext(message_context=MessageContext(), part_index=part_index) + + def test_first_chunk_produces_replace_root_patch(self): + ctx = self._make() + patches = ctx.add_chunk("Hello") + assert len(patches) == 1 + patch = patches[0] + assert patch["op"] == "replace" + assert patch["path"] == "" + value = patch["value"] # type: ignore[typeddict-item] + assert value["parts"] == [{"text": "Hello"}] # type: ignore[index] + assert "message_id" in value # type: ignore[operator] + + def test_subsequent_chunk_produces_str_ins(self): + ctx = self._make() + ctx.add_chunk("Hello") + patches = ctx.add_chunk(" world") + assert len(patches) == 1 + patch = patches[0] + assert patch["op"] == "str_ins" + assert patch["path"] == "/parts/0/text" + assert patch["value"] == " world" + assert patch["pos"] == 5 + + def test_pos_advances_by_chunk_length(self): + ctx = self._make() + assert ctx.pos == 0 + ctx.add_chunk("abc") + assert ctx.pos == 3 + ctx.add_chunk("de") + assert ctx.pos == 5 + ctx.add_chunk("f") + assert ctx.pos == 6 + + def test_build_concatenates_chunks(self): + ctx = self._make() + ctx.add_chunk("Hello") + ctx.add_chunk(" ") + ctx.add_chunk("world") + part = ctx.build() + assert part.text == "Hello world" + + def test_str_ins_uses_part_index(self): + ctx = self._make(part_index=3) + ctx.add_chunk("Hello") + patches = ctx.add_chunk(" world") + assert patches[0]["path"] == "/parts/3/text" + + def test_build_empty_chunks(self): + ctx = self._make() + part = ctx.build() + assert part.text == "" + + +# --- MessageContext --- + + +class TestMessageContext: + def test_first_add_part_returns_replace_root_patch(self): + ctx = MessageContext() + patches = ctx.add_part(Part(text="hello")) + assert len(patches) == 1 + patch = patches[0] + assert patch["op"] == "replace" + assert patch["path"] == "" + value = patch["value"] # type: ignore[typeddict-item] + assert value["parts"] == [{"text": "hello"}] # type: ignore[index] + assert "message_id" in value # type: ignore[operator] + + def test_second_add_part_returns_add_patch(self): + ctx = MessageContext() + ctx.add_part(Part(text="first")) + patches = ctx.add_part(Part(text="second")) + assert len(patches) == 1 + patch = patches[0] + assert patch["op"] == "add" + assert patch["path"] == "/parts/-" + assert patch["value"]["text"] == "second" + + def test_add_part_accumulates(self): + ctx = MessageContext() + ctx.add_part(Part(text="a")) + ctx.add_part(Part(text="b")) + assert len(ctx.parts) == 2 + assert ctx.parts[0].text == "a" + assert ctx.parts[1].text == "b" + + def test_first_add_metadata_returns_replace_root_patch(self): + ctx = MessageContext() + patches = ctx.add_metadata(Metadata({"key": "value"})) + assert len(patches) == 1 + patch = patches[0] + assert patch["op"] == "replace" + assert patch["path"] == "" + value = patch["value"] # type: ignore[typeddict-item] + assert value["parts"] == [] # type: ignore[index] + assert value["metadata"] == {"key": "value"} # type: ignore[index] + assert "message_id" in value # type: ignore[operator] + + def test_add_metadata_after_part_returns_add_patch(self): + ctx = MessageContext() + ctx.add_part(Part(text="hello")) + patches = ctx.add_metadata(Metadata({"key": "value"})) + assert len(patches) == 1 + patch = patches[0] + assert patch["op"] == "add" + assert patch["path"] == "/metadata" + assert patch["value"] == {"key": "value"} + + def test_add_metadata_deep_merges(self): + ctx = MessageContext() + ctx.add_metadata(Metadata({"a": 1})) + ctx.add_metadata(Metadata({"b": 2})) + assert ctx.metadata == {"a": 1, "b": 2} + + def test_add_metadata_second_call_returns_incremental_patches(self): + ctx = MessageContext() + ctx.add_metadata(Metadata({"a": 1})) + patches = ctx.add_metadata(Metadata({"b": 2})) + assert len(patches) >= 1 + # Should produce an add for the new key, not a full replace + assert all(op["path"].startswith("/metadata/") for op in patches) + # Verify the patches are correct by applying them + from kagenti_adk.server.jsonpatch_ext import ExtendedJsonPatch + draft = {"metadata": {"a": 1}} + draft = ExtendedJsonPatch(patches).apply(draft) + assert draft["metadata"] == {"a": 1, "b": 2} + + def test_add_metadata_deep_merges_nested(self): + ctx = MessageContext() + ctx.add_metadata(Metadata({"ext": {"x": 1}})) + ctx.add_metadata(Metadata({"ext": {"y": 2}})) + assert ctx.metadata == {"ext": {"x": 1, "y": 2}} + + def test_add_metadata_concatenates_arrays(self): + """Extensions using arrays (e.g. citations, trajectory) should accumulate via concatenation.""" + uri = "https://example.com/ext/v1" + ctx = MessageContext() + ctx.add_metadata(Metadata({uri: [{"title": "a"}]})) + ctx.add_metadata(Metadata({uri: [{"title": "b"}]})) + assert ctx.metadata == {uri: [{"title": "a"}, {"title": "b"}]} + + def test_build_creates_agent_message(self): + ctx = MessageContext() + ctx.add_part(Part(text="hello")) + ctx.add_metadata(Metadata({"key": "value"})) + msg = ctx.build() + assert msg.role == Role.ROLE_AGENT + assert msg.message_id == str(ctx.message_id) + assert len(msg.parts) == 1 + assert msg.parts[0].text == "hello" + assert MessageToDict(msg.metadata) == {"key": "value"} + + def test_build_without_metadata(self): + ctx = MessageContext() + ctx.add_part(Part(text="hello")) + msg = ctx.build() + assert msg.role == Role.ROLE_AGENT + assert len(msg.parts) == 1 + assert not MessageToDict(msg.metadata) + + def test_build_with_multiple_parts(self): + ctx = MessageContext() + ctx.add_part(Part(text="a")) + ctx.add_part(Part(text="b")) + ctx.add_part(Part(text="c")) + msg = ctx.build() + assert [p.text for p in msg.parts] == ["a", "b", "c"] diff --git a/apps/adk-py/tests/unit/server/test_jsonpatch_ext.py b/apps/adk-py/tests/unit/server/test_jsonpatch_ext.py new file mode 100644 index 00000000..b8a81857 --- /dev/null +++ b/apps/adk-py/tests/unit/server/test_jsonpatch_ext.py @@ -0,0 +1,78 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC + +import jsonpatch +import pytest + +from kagenti_adk.server.jsonpatch_ext import StrInsOperation, make_patch + + +def test_str_ins_append(): + """Test optimized append detection""" + src = {"text": "Hello"} + dst = {"text": "Hello World"} + patch = make_patch(src, dst) + ops = list(patch) + assert len(ops) == 1 + assert ops[0]["op"] == "str_ins" + assert ops[0]["path"] == "/text" + assert ops[0]["pos"] == 5 + assert ops[0]["value"] == " World" + + # Test apply + assert patch.apply(src) == dst + + +def test_str_ins_middle(): + """Test middle insertion detection via difflib""" + src = {"text": "Hello World"} + dst = {"text": "Hello beautiful World"} + patch = make_patch(src, dst) + ops = list(patch) + assert len(ops) == 1 + assert ops[0]["op"] == "str_ins" + assert ops[0]["path"] == "/text" + assert ops[0]["value"] == "beautiful " + assert ops[0]["pos"] == 6 + + assert patch.apply(src) == dst + + +def test_str_ins_prepend(): + """Test prepend detection""" + src = {"text": "World"} + dst = {"text": "Hello World"} + patch = make_patch(src, dst) + ops = list(patch) + assert len(ops) == 1 + assert ops[0]["op"] == "str_ins" + assert ops[0]["pos"] == 0 + assert ops[0]["value"] == "Hello " + + assert patch.apply(src) == dst + + +def test_str_ins_complex_fallback(): + """Test that multiple changes fallback to replace""" + src = {"text": "Hello World"} + dst = {"text": "Hi World!"} # Changed Start and End + patch = make_patch(src, dst) + ops = list(patch) + assert len(ops) == 1 + assert ops[0]["op"] == "replace" + + assert patch.apply(src) == dst + + +def test_str_ins_explicit_apply(): + """Test manual StrInsOperation application""" + obj = {"foo": ["bar"]} + op = StrInsOperation({"op": "str_ins", "path": "/foo/0", "pos": 1, "value": "az"}) + res = op.apply(obj) + assert res["foo"][0] == "bazar" + + +def test_str_ins_out_of_bounds(): + obj = {"text": "foo"} + op = StrInsOperation({"op": "str_ins", "path": "/text", "pos": 100, "value": "bar"}) + with pytest.raises(jsonpatch.JsonPatchConflict): + op.apply(obj) diff --git a/apps/adk-py/tests/unit/test_agent_detail_population.py b/apps/adk-py/tests/unit/test_agent_detail_population.py index a01179bd..89cf8b36 100644 --- a/apps/adk-py/tests/unit/test_agent_detail_population.py +++ b/apps/adk-py/tests/unit/test_agent_detail_population.py @@ -42,8 +42,11 @@ def test_agent_fn(): # Check tools populated from skills assert "tools" in params tools = params["tools"] + # pyrefly: ignore [bad-argument-type] assert len(tools) == 1 + # pyrefly: ignore [unsupported-operation, bad-index] assert tools[0]["name"] == "test_skill" + # pyrefly: ignore [unsupported-operation, bad-index] assert tools[0]["description"] == "A test skill" # Check user_greeting populated from description @@ -88,7 +91,9 @@ def test_agent_fn(): # Tools should still be populated because custom_detail.tools was None assert "tools" in params tools = params["tools"] + # pyrefly: ignore [bad-argument-type] assert len(tools) == 1 + # pyrefly: ignore [unsupported-operation, bad-index] assert tools[0]["name"] == "test_skill" # Greeting should be custom diff --git a/apps/adk-py/tests/unit/test_dependencies.py b/apps/adk-py/tests/unit/test_dependencies.py index 53c353db..362b1075 100644 --- a/apps/adk-py/tests/unit/test_dependencies.py +++ b/apps/adk-py/tests/unit/test_dependencies.py @@ -7,18 +7,26 @@ import pytest -from kagenti_adk.a2a.extensions import CitationExtensionServer, CitationExtensionSpec +from kagenti_adk.a2a.extensions import ( + CitationExtensionServer, + CitationExtensionSpec, + TrajectoryExtensionServer, + TrajectoryExtensionSpec, +) +from kagenti_adk.a2a.extensions.streaming import StreamingExtensionServer, StreamingExtensionSpec from kagenti_adk.server.dependencies import extract_dependencies +pytestmark = pytest.mark.unit + class MyExtensions(TypedDict): a: Annotated[CitationExtensionServer, CitationExtensionSpec()] - b: Annotated[CitationExtensionServer, CitationExtensionSpec()] - c: Annotated[CitationExtensionServer, CitationExtensionSpec()] + b: Annotated[TrajectoryExtensionServer, TrajectoryExtensionSpec()] + c: Annotated[StreamingExtensionServer, StreamingExtensionSpec()] class MyExtensionsComplex(TypedDict): - b: Annotated[CitationExtensionServer, CitationExtensionSpec()] + b: Annotated[TrajectoryExtensionServer, TrajectoryExtensionSpec()] @pytest.mark.unit @@ -29,7 +37,6 @@ def agent(a: Annotated[CitationExtensionServer, CitationExtensionSpec()]) -> Non assert extract_dependencies(agent).keys() == {"a"} -@pytest.mark.unit def test_extract_dependencies_extra_parameters() -> None: def agent(a: Annotated[CitationExtensionServer, CitationExtensionSpec()], b: bool) -> None: pass @@ -48,8 +55,8 @@ def agent(**kwargs: Unpack[MyExtensions]) -> None: assert extract_dependencies(agent).keys() == {"a", "b", "c"} -@pytest.mark.unit def test_extract_dependencies_complex() -> None: + def agent( a: Annotated[CitationExtensionServer, CitationExtensionSpec()], **kwargs: Unpack[MyExtensionsComplex], diff --git a/apps/adk-py/uv.lock b/apps/adk-py/uv.lock index 98d1744e..18af3e37 100644 --- a/apps/adk-py/uv.lock +++ b/apps/adk-py/uv.lock @@ -1119,6 +1119,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/94/9e/820c4b086ad01ba7d77369fb8b11470a01fac9b4977f02e18659cf378b6b/json_rpc-1.15.0-py2.py3-none-any.whl", hash = "sha256:4a4668bbbe7116feb4abbd0f54e64a4adcf4b8f648f19ffa0848ad0f6606a9bf", size = 39450, upload-time = "2023-06-11T09:45:47.136Z" }, ] +[[package]] +name = "jsonpatch" +version = "1.33" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jsonpointer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/78/18813351fe5d63acad16aec57f94ec2b70a09e53ca98145589e185423873/jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c", size = 21699, upload-time = "2023-06-26T12:07:29.144Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/07/02e16ed01e04a374e644b575638ec7987ae846d25ad97bcc9945a3ee4b0e/jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade", size = 12898, upload-time = "2023-06-16T21:01:28.466Z" }, +] + +[[package]] +name = "jsonpointer" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/0a/eebeb1fa92507ea94016a2a790b93c2ae41a7e18778f85471dc54475ed25/jsonpointer-3.0.0.tar.gz", hash = "sha256:2b2d729f2091522d61c3b31f82e11870f60b68f43fbc705cb76bf4b832af59ef", size = 9114, upload-time = "2024-06-10T19:24:42.462Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/71/92/5e77f98553e9e75130c78900d000368476aed74276eb8ae8796f65f00918/jsonpointer-3.0.0-py2.py3-none-any.whl", hash = "sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942", size = 7595, upload-time = "2024-06-10T19:24:40.698Z" }, +] + [[package]] name = "jsonref" version = "1.1.0" @@ -1162,6 +1183,7 @@ source = { editable = "." } dependencies = [ { name = "a2a-sdk", extra = ["sqlite"] }, { name = "anyio" }, + { name = "asgiref" }, { name = "async-lru" }, { name = "asyncclick" }, { name = "authlib" }, @@ -1169,6 +1191,7 @@ dependencies = [ { name = "fastapi" }, { name = "httpx" }, { name = "janus" }, + { name = "jsonpatch" }, { name = "mcp" }, { name = "objprint" }, { name = "opentelemetry-api" }, @@ -1198,6 +1221,7 @@ dev = [ requires-dist = [ { name = "a2a-sdk", extras = ["sqlite"], specifier = "==1.0.0a0" }, { name = "anyio", specifier = ">=4.9.0" }, + { name = "asgiref", specifier = ">=3.11.0" }, { name = "async-lru", specifier = ">=2.0.4" }, { name = "asyncclick", specifier = ">=8.1.8" }, { name = "authlib", specifier = ">=1.3.0" }, @@ -1205,6 +1229,7 @@ requires-dist = [ { name = "fastapi", specifier = ">=0.116.1" }, { name = "httpx" }, { name = "janus", specifier = ">=2.0.0" }, + { name = "jsonpatch", specifier = ">=1.33" }, { name = "mcp", specifier = ">=1.12.3" }, { name = "objprint", specifier = ">=0.3.0" }, { name = "opentelemetry-api", specifier = ">=1.35.0" }, diff --git a/apps/adk-server/pyproject.toml b/apps/adk-server/pyproject.toml index d813fcbb..8d03076a 100644 --- a/apps/adk-server/pyproject.toml +++ b/apps/adk-server/pyproject.toml @@ -52,6 +52,7 @@ dependencies = [ "opentelemetry-instrumentation-openai>=0.52.1", "opentelemetry-instrumentation-sqlalchemy>=0.60b1", "aws-bedrock-token-generator>=1.1.0", + "jsonpatch>=1.33", ] [dependency-groups] diff --git a/apps/adk-server/src/adk_server/api/routes/a2a.py b/apps/adk-server/src/adk_server/api/routes/a2a.py index cb390345..b14d15c5 100644 --- a/apps/adk-server/src/adk_server/api/routes/a2a.py +++ b/apps/adk-server/src/adk_server/api/routes/a2a.py @@ -100,7 +100,7 @@ async def get_agent_card( card_copy = create_proxy_agent_card( provider.agent_card, provider_id=provider.id, request=request, configuration=configuration ) - return MessageToDict(card_copy, preserving_proto_field_name=True) + return MessageToDict(card_copy) @router.post("/{provider_id}") diff --git a/apps/adk-server/tests/e2e/agents/test_platform_extensions.py b/apps/adk-server/tests/e2e/agents/test_platform_extensions.py index 54bbb147..5c6822c3 100644 --- a/apps/adk-server/tests/e2e/agents/test_platform_extensions.py +++ b/apps/adk-server/tests/e2e/agents/test_platform_extensions.py @@ -8,8 +8,8 @@ from uuid import uuid4 import pytest -from a2a.client import Client, create_text_message_object -from a2a.types import SendMessageRequest, Message, Role, TaskState +from a2a.client import Client +from a2a.types import Message, Role, SendMessageRequest, TaskState from kagenti_adk.a2a.extensions.services.platform import ( PlatformApiExtensionClient, PlatformApiExtensionServer, @@ -90,15 +90,13 @@ async def test_platform_api_extension(file_reader_writer_factory, permissions, s else: assert task.status.state == TaskState.TASK_STATE_COMPLETED, f"Fail: {task.status.message.parts[0].text}" - # check that first message is the content of the first_file - first_message_text = task.history[0].parts[0].text - assert first_message_text == "01234" - - second_message_text = task.history[1].parts[0].text - assert second_message_text == "56789" + # accumulator combines consecutive string chunks into a single text part, + # so we get one message with text + file parts + msg = task.history[0] + assert msg.parts[0].text == "0123456789" # check that the agent uploaded a new file with correct context_id as content - async with load_file(task.history[2].parts[0]) as file: + async with load_file(msg.parts[1]) as file: assert file.text == context.id @@ -126,4 +124,3 @@ async def test_self_registration(self_registration_agent, subtests): assert provider.state == "online" assert "self_registration_agent" in provider.source - diff --git a/apps/adk-server/uv.lock b/apps/adk-server/uv.lock index 970842d6..dd304e58 100644 --- a/apps/adk-server/uv.lock +++ b/apps/adk-server/uv.lock @@ -46,6 +46,7 @@ dependencies = [ { name = "httpx" }, { name = "ibm-watsonx-ai" }, { name = "ijson" }, + { name = "jsonpatch" }, { name = "kink" }, { name = "kr8s" }, { name = "limits", extra = ["async-redis"] }, @@ -106,6 +107,7 @@ requires-dist = [ { name = "httpx", specifier = ">=0.28.1" }, { name = "ibm-watsonx-ai", specifier = ">=1.3.28" }, { name = "ijson", specifier = ">=3.4.0.post0" }, + { name = "jsonpatch", specifier = ">=1.33" }, { name = "kink", specifier = ">=0.8.1" }, { name = "kr8s", specifier = ">=0.20.7" }, { name = "limits", extras = ["async-redis"], specifier = ">=5.3.0" }, @@ -1203,6 +1205,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/94/9e/820c4b086ad01ba7d77369fb8b11470a01fac9b4977f02e18659cf378b6b/json_rpc-1.15.0-py2.py3-none-any.whl", hash = "sha256:4a4668bbbe7116feb4abbd0f54e64a4adcf4b8f648f19ffa0848ad0f6606a9bf", size = 39450, upload-time = "2023-06-11T09:45:47.136Z" }, ] +[[package]] +name = "jsonpatch" +version = "1.33" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jsonpointer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/78/18813351fe5d63acad16aec57f94ec2b70a09e53ca98145589e185423873/jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c", size = 21699, upload-time = "2023-06-26T12:07:29.144Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/07/02e16ed01e04a374e644b575638ec7987ae846d25ad97bcc9945a3ee4b0e/jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade", size = 12898, upload-time = "2023-06-16T21:01:28.466Z" }, +] + +[[package]] +name = "jsonpointer" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/0a/eebeb1fa92507ea94016a2a790b93c2ae41a7e18778f85471dc54475ed25/jsonpointer-3.0.0.tar.gz", hash = "sha256:2b2d729f2091522d61c3b31f82e11870f60b68f43fbc705cb76bf4b832af59ef", size = 9114, upload-time = "2024-06-10T19:24:42.462Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/71/92/5e77f98553e9e75130c78900d000368476aed74276eb8ae8796f65f00918/jsonpointer-3.0.0-py2.py3-none-any.whl", hash = "sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942", size = 7595, upload-time = "2024-06-10T19:24:40.698Z" }, +] + [[package]] name = "jsonschema" version = "4.26.0" @@ -1250,6 +1273,7 @@ source = { editable = "../adk-py" } dependencies = [ { name = "a2a-sdk", extra = ["sqlite"] }, { name = "anyio" }, + { name = "asgiref" }, { name = "async-lru" }, { name = "asyncclick" }, { name = "authlib" }, @@ -1257,6 +1281,7 @@ dependencies = [ { name = "fastapi" }, { name = "httpx" }, { name = "janus" }, + { name = "jsonpatch" }, { name = "mcp" }, { name = "objprint" }, { name = "opentelemetry-api" }, @@ -1276,6 +1301,7 @@ dependencies = [ requires-dist = [ { name = "a2a-sdk", extras = ["sqlite"], specifier = "==1.0.0a0" }, { name = "anyio", specifier = ">=4.9.0" }, + { name = "asgiref", specifier = ">=3.11.0" }, { name = "async-lru", specifier = ">=2.0.4" }, { name = "asyncclick", specifier = ">=8.1.8" }, { name = "authlib", specifier = ">=1.3.0" }, @@ -1283,6 +1309,7 @@ requires-dist = [ { name = "fastapi", specifier = ">=0.116.1" }, { name = "httpx" }, { name = "janus", specifier = ">=2.0.0" }, + { name = "jsonpatch", specifier = ">=1.33" }, { name = "mcp", specifier = ">=1.12.3" }, { name = "objprint", specifier = ">=0.3.0" }, { name = "opentelemetry-api", specifier = ">=1.35.0" }, diff --git a/apps/adk-ts/src/client/a2a/extensions/index.ts b/apps/adk-ts/src/client/a2a/extensions/index.ts index dc565d55..e0f4cfda 100644 --- a/apps/adk-ts/src/client/a2a/extensions/index.ts +++ b/apps/adk-ts/src/client/a2a/extensions/index.ts @@ -19,4 +19,5 @@ export * from './ui/citation'; export * from './ui/error'; export * from './ui/form-request'; export * from './ui/settings'; +export * from './ui/streaming'; export * from './ui/trajectory'; diff --git a/apps/adk-ts/src/client/a2a/extensions/types.ts b/apps/adk-ts/src/client/a2a/extensions/types.ts index 730bd2e8..36713071 100644 --- a/apps/adk-ts/src/client/a2a/extensions/types.ts +++ b/apps/adk-ts/src/client/a2a/extensions/types.ts @@ -17,4 +17,5 @@ export * from './ui/canvas/types'; export * from './ui/citation/types'; export * from './ui/error/types'; export * from './ui/settings/types'; +export * from './ui/streaming/types'; export * from './ui/trajectory/types'; diff --git a/apps/adk-ts/src/client/a2a/extensions/ui/citation/index.ts b/apps/adk-ts/src/client/a2a/extensions/ui/citation/index.ts index c18c34de..e88140a7 100644 --- a/apps/adk-ts/src/client/a2a/extensions/ui/citation/index.ts +++ b/apps/adk-ts/src/client/a2a/extensions/ui/citation/index.ts @@ -6,12 +6,12 @@ import z from 'zod'; import type { A2AUiExtension } from '../../../../core/extensions/types'; -import { citationMetadataSchema } from './schemas'; -import type { CitationMetadata } from './types'; +import { citationSchema } from './schemas'; +import type { Citation } from './types'; export const CITATION_EXTENSION_URI = 'https://a2a-extensions.adk.kagenti.dev/ui/citation/v1'; -export const citationExtension: A2AUiExtension = { +export const citationExtension: A2AUiExtension = { getUri: () => CITATION_EXTENSION_URI, - getMessageMetadataSchema: () => z.object({ [CITATION_EXTENSION_URI]: citationMetadataSchema }).partial(), + getMessageMetadataSchema: () => z.object({ [CITATION_EXTENSION_URI]: z.array(citationSchema) }).partial(), }; diff --git a/apps/adk-ts/src/client/a2a/extensions/ui/streaming/index.ts b/apps/adk-ts/src/client/a2a/extensions/ui/streaming/index.ts new file mode 100644 index 00000000..d1186faf --- /dev/null +++ b/apps/adk-ts/src/client/a2a/extensions/ui/streaming/index.ts @@ -0,0 +1,20 @@ +/** + * Copyright 2025 © BeeAI a Series of LF Projects, LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import z from 'zod'; + +import type { A2AUiExtension } from '../../../../core/extensions/types'; +import { streamingMetadataSchema } from './schemas'; +import type { StreamingMetadata } from './types'; + +export type { StreamingMetadata, StreamingPatch } from './types'; +export { streamingMetadataSchema, streamingPatchSchema } from './schemas'; + +export const STREAMING_EXTENSION_URI = 'https://a2a-extensions.agentstack.beeai.dev/ui/streaming/v1'; + +export const streamingExtension: A2AUiExtension = { + getUri: () => STREAMING_EXTENSION_URI, + getMessageMetadataSchema: () => z.object({ [STREAMING_EXTENSION_URI]: streamingMetadataSchema }).partial(), +}; diff --git a/apps/adk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts b/apps/adk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts new file mode 100644 index 00000000..3cce4408 --- /dev/null +++ b/apps/adk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts @@ -0,0 +1,18 @@ +/** + * Copyright 2025 © BeeAI a Series of LF Projects, LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import z from 'zod'; + +export const streamingPatchSchema = z.object({ + op: z.string(), + path: z.string(), + value: z.unknown().optional(), + pos: z.number().optional(), // for str_ins +}); + +export const streamingMetadataSchema = z.object({ + message_update: z.array(streamingPatchSchema).optional(), + message_id: z.string().optional(), +}); diff --git a/apps/adk-ts/src/client/a2a/extensions/ui/streaming/types.ts b/apps/adk-ts/src/client/a2a/extensions/ui/streaming/types.ts new file mode 100644 index 00000000..aa10398f --- /dev/null +++ b/apps/adk-ts/src/client/a2a/extensions/ui/streaming/types.ts @@ -0,0 +1,12 @@ +/** + * Copyright 2025 © BeeAI a Series of LF Projects, LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type z from 'zod'; + +import type { streamingMetadataSchema, streamingPatchSchema } from './schemas'; + +export type StreamingPatch = z.infer; + +export type StreamingMetadata = z.infer; diff --git a/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/index.ts b/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/index.ts index d18b230d..7ab8aac8 100644 --- a/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/index.ts +++ b/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/index.ts @@ -11,7 +11,7 @@ import type { TrajectoryMetadata } from './types'; export const TRAJECTORY_EXTENSION_URI = 'https://a2a-extensions.adk.kagenti.dev/ui/trajectory/v1'; -export const trajectoryExtension: A2AUiExtension = { +export const trajectoryExtension: A2AUiExtension = { getUri: () => TRAJECTORY_EXTENSION_URI, - getMessageMetadataSchema: () => z.object({ [TRAJECTORY_EXTENSION_URI]: trajectoryMetadataSchema }).partial(), + getMessageMetadataSchema: () => z.object({ [TRAJECTORY_EXTENSION_URI]: z.array(trajectoryMetadataSchema) }).partial(), }; diff --git a/apps/adk-ui/src/api/a2a/agent-card.ts b/apps/adk-ui/src/api/a2a/agent-card.ts index 5bbc46ac..2ab80317 100644 --- a/apps/adk-ui/src/api/a2a/agent-card.ts +++ b/apps/adk-ui/src/api/a2a/agent-card.ts @@ -19,7 +19,8 @@ export async function getAgentClient(providerId: string, token: string): Promise const agentCard = await fetchAgentCard(agentCardUrl, fetchImpl); - return createA2AClient({ endpointUrl, agentCard, fetchImpl }); + const extensions = agentCard.capabilities?.extensions?.map((ext) => ext.uri).filter(Boolean) as string[]; + return createA2AClient({ endpointUrl, agentCard, fetchImpl, extensions }); } async function clientFetch(input: RequestInfo, init?: RequestInit) { diff --git a/apps/adk-ui/src/api/a2a/client.ts b/apps/adk-ui/src/api/a2a/client.ts index fbe1b796..0b3afe09 100644 --- a/apps/adk-ui/src/api/a2a/client.ts +++ b/apps/adk-ui/src/api/a2a/client.ts @@ -17,6 +17,7 @@ import { getAgentClient } from './agent-card'; import { AGENT_ERROR_MESSAGE } from './constants'; import type { A2AClient } from './jsonrpc-client'; import { processArtifactMetadata, processMessageMetadata, processParts } from './part-processors'; +import { applyPatches, extractStreamingPatches } from './streaming'; import type { ChatResult, TaskStatusUpdateResultWithTaskId } from './types'; import { type ChatParams, type ChatRun, RunResultType } from './types'; import { createUserMessage, extractErrorExtension } from './utils'; @@ -124,6 +125,7 @@ export const buildA2AClient = async ({ const messageSubject = new Subject>(); let taskId: undefined | TaskId = initialTaskId; + const streamingDraft: Record = {}; const iterateOverStream = async () => { const agentCardMetadata = await resolveAgentCardMetadata(fulfillments); @@ -153,6 +155,28 @@ export const buildA2AClient = async ({ .with({ statusUpdate: P.nonNullable }, ({ statusUpdate }) => { taskId = statusUpdate.taskId; + // Check for streaming patches in metadata + const patches = extractStreamingPatches( + statusUpdate.metadata as Record | undefined, + ); + + if (patches && taskId) { + // Apply patches to draft and emit as a replace update + applyPatches(streamingDraft, patches); + const draftParts = (streamingDraft.parts as Array>) ?? []; + const uiParts: UIMessagePart[] = draftParts + .map((part): UIMessagePart | null => { + if (typeof part.text === 'string') { + return { kind: UIMessagePartKind.Text, id: 'streaming-text', text: part.text } as UITextPart; + } + return null; + }) + .filter((p): p is UIMessagePart => p !== null); + + messageSubject.next({ type: RunResultType.Parts, parts: uiParts, taskId, replace: true }); + return; + } + handleTaskStatusUpdate(statusUpdate).forEach((result) => { if (!taskId) { throw new Error(`Illegal State - taskId missing on status-update event`); @@ -166,10 +190,17 @@ export const buildA2AClient = async ({ const parts: (UIMessagePart | UIGenericPart)[] = handleStatusUpdate(statusUpdate, onStatusUpdate); + // On final message, clear the streaming draft and replace the + // streamed text so it is not duplicated by the complete message. + const wasStreaming = Object.keys(streamingDraft).length > 0; + if (wasStreaming) { + Object.keys(streamingDraft).forEach((key) => delete streamingDraft[key]); + } + if (!taskId) { throw new Error(`Illegal State - taskId missing on status-update event`); } - messageSubject.next({ type: RunResultType.Parts, parts, taskId }); + messageSubject.next({ type: RunResultType.Parts, parts, taskId, replace: wasStreaming }); }) .with({ artifactUpdate: P.nonNullable }, ({ artifactUpdate }) => { taskId = artifactUpdate.taskId; diff --git a/apps/adk-ui/src/api/a2a/jsonrpc-client.ts b/apps/adk-ui/src/api/a2a/jsonrpc-client.ts index 5c901b4c..9083bd49 100644 --- a/apps/adk-ui/src/api/a2a/jsonrpc-client.ts +++ b/apps/adk-ui/src/api/a2a/jsonrpc-client.ts @@ -23,13 +23,19 @@ interface CreateClientParams { endpointUrl: string; agentCard: AgentCard; fetchImpl: typeof fetch; + extensions?: string[]; } -export function createA2AClient({ endpointUrl, agentCard, fetchImpl }: CreateClientParams): A2AClient { +export function createA2AClient({ endpointUrl, agentCard, fetchImpl, extensions }: CreateClientParams): A2AClient { + const requestHeaders: Record = { 'Content-Type': 'application/json' }; + if (extensions?.length) { + requestHeaders['X-A2A-Extensions'] = extensions.join(','); + } + async function jsonRpcRequest(method: string, params: Record) { const response = await fetchImpl(endpointUrl, { method: 'POST', - headers: { 'Content-Type': 'application/json' }, + headers: requestHeaders, body: JSON.stringify({ jsonrpc: '2.0', id: uuid(), @@ -59,7 +65,7 @@ export function createA2AClient({ endpointUrl, agentCard, fetchImpl }: CreateCli async *sendMessageStream(params) { const response = await fetchImpl(endpointUrl, { method: 'POST', - headers: { 'Content-Type': 'application/json' }, + headers: requestHeaders, body: JSON.stringify({ jsonrpc: '2.0', id: uuid(), diff --git a/apps/adk-ui/src/api/a2a/part-processors.ts b/apps/adk-ui/src/api/a2a/part-processors.ts index f04aa591..e181cf59 100644 --- a/apps/adk-ui/src/api/a2a/part-processors.ts +++ b/apps/adk-ui/src/api/a2a/part-processors.ts @@ -21,12 +21,14 @@ import { export function processMessageMetadata(message: Message): UIMessagePart[] { const trajectory = extractTrajectory(message.metadata); - const citations = extractCitation(message.metadata)?.citations; + const citations = extractCitation(message.metadata); const parts: UIMessagePart[] = []; if (trajectory) { - parts.push(createTrajectoryPart(trajectory)); + for (const item of trajectory) { + parts.push(createTrajectoryPart(item)); + } } if (citations) { const sourceParts = citations.map((citation) => createSourcePart(citation, message.taskId)).filter(isNotNull); @@ -38,7 +40,7 @@ export function processMessageMetadata(message: Message): UIMessagePart[] { } export function processArtifactMetadata(artifact: Artifact, taskId: string): UISourcePart[] { - const citations = extractCitation(artifact.metadata)?.citations; + const citations = extractCitation(artifact.metadata); if (!citations) { return []; diff --git a/apps/adk-ui/src/api/a2a/streaming.ts b/apps/adk-ui/src/api/a2a/streaming.ts new file mode 100644 index 00000000..bdf03d1d --- /dev/null +++ b/apps/adk-ui/src/api/a2a/streaming.ts @@ -0,0 +1,136 @@ +/** + * Copyright 2025 © BeeAI a Series of LF Projects, LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { StreamingMetadata, StreamingPatch } from '@kagenti/adk'; +import { STREAMING_EXTENSION_URI } from '@kagenti/adk'; + +/** + * Extract streaming patches from status update metadata. + */ +export function extractStreamingPatches( + metadata: Record | undefined | null, +): StreamingPatch[] | null { + if (!metadata) return null; + const streamingData = metadata[STREAMING_EXTENSION_URI] as StreamingMetadata | undefined; + if (!streamingData?.message_update?.length) return null; + return streamingData.message_update; +} + +/** + * Resolve a JSON pointer path to get/set a value in a nested object. + * Returns [parent, key] for the target location. + */ +function resolvePath(obj: Record, path: string): [Record, string] { + if (path === '' || path === '/') return [obj, '']; + + const parts = path.split('/').filter(Boolean); + let current: unknown = obj; + + for (let i = 0; i < parts.length - 1; i++) { + const key = parts[i]; + if (Array.isArray(current)) { + current = current[Number(key)]; + } else if (current && typeof current === 'object') { + current = (current as Record)[key]; + } + } + + return [current as Record, parts[parts.length - 1]]; +} + +function getByPath(obj: Record, path: string): unknown { + if (path === '' || path === '/') return obj; + const parts = path.split('/').filter(Boolean); + let current: unknown = obj; + for (const key of parts) { + if (Array.isArray(current)) { + current = current[Number(key)]; + } else if (current && typeof current === 'object') { + current = (current as Record)[key]; + } else { + return undefined; + } + } + return current; +} + +/** Clone only objects; primitives (strings, numbers, booleans) are immutable and don't need cloning. */ +function cloneValue(v: T): T { + return typeof v === 'object' && v !== null ? structuredClone(v) : v; +} + +/** + * Apply a single streaming patch to a draft message object. + * Supports: replace, add, str_ins (custom string insertion). + */ +function applyPatch(draft: Record, patch: StreamingPatch): void { + const { op, path, value } = patch; + + if (op === 'replace') { + if (path === '' || path === '/') { + // Root replace — overwrite entire draft + Object.keys(draft).forEach((key) => delete draft[key]); + if (value && typeof value === 'object') { + Object.assign(draft, cloneValue(value)); + } + } else { + const [parent, key] = resolvePath(draft, path); + if (parent && typeof parent === 'object') { + if (Array.isArray(parent)) { + parent[Number(key)] = cloneValue(value); + } else { + parent[key] = cloneValue(value); + } + } + } + } else if (op === 'add') { + if (path.endsWith('/-')) { + // Append to array + const arrayPath = path.slice(0, -2); + const arr = getByPath(draft, arrayPath); + if (Array.isArray(arr)) { + arr.push(cloneValue(value)); + } + } else { + const [parent, key] = resolvePath(draft, path); + if (parent && typeof parent === 'object') { + if (Array.isArray(parent)) { + parent.splice(Number(key), 0, cloneValue(value)); + } else { + parent[key] = cloneValue(value); + } + } + } + } else if (op === 'str_ins') { + // Custom operation: insert string at position + const pos = patch.pos ?? 0; + const current = getByPath(draft, path); + if (typeof current === 'string' && typeof value === 'string') { + const newValue = current.slice(0, pos) + value + current.slice(pos); + const [parent, key] = resolvePath(draft, path); + if (parent && typeof parent === 'object') { + if (Array.isArray(parent)) { + parent[Number(key)] = newValue; + } else { + parent[key] = newValue; + } + } + } + } +} + +/** + * Apply an array of streaming patches to a draft message object. + * Mutates the draft in place and returns it. + */ +export function applyPatches( + draft: Record, + patches: StreamingPatch[], +): Record { + for (const patch of patches) { + applyPatch(draft, patch); + } + return draft; +} diff --git a/apps/adk-ui/src/api/a2a/types.ts b/apps/adk-ui/src/api/a2a/types.ts index 89aa5dff..d7077eb1 100644 --- a/apps/adk-ui/src/api/a2a/types.ts +++ b/apps/adk-ui/src/api/a2a/types.ts @@ -18,6 +18,8 @@ export interface PartsResult { type: RunResultType.Parts; taskId: TaskId; parts: Array; + /** When true, replace all current parts instead of appending (used for streaming updates). */ + replace?: boolean; } export type TaskStatusUpdateResultWithTaskId = TaskStatusUpdateResult & { @@ -37,7 +39,9 @@ export interface ChatParams { export interface ChatRun { taskId?: TaskId; done: Promise; - subscribe: (fn: (data: { parts: (UIMessagePart | UIGenericPart)[]; taskId: TaskId }) => void) => () => void; + subscribe: ( + fn: (data: { parts: (UIMessagePart | UIGenericPart)[]; taskId: TaskId; replace?: boolean }) => void, + ) => () => void; cancel: () => Promise; } diff --git a/apps/adk-ui/src/api/a2a/utils.ts b/apps/adk-ui/src/api/a2a/utils.ts index 4e520ddf..303ee976 100644 --- a/apps/adk-ui/src/api/a2a/utils.ts +++ b/apps/adk-ui/src/api/a2a/utils.ts @@ -9,6 +9,7 @@ import { citationExtension, errorExtension, extractUiExtensionData, + streamingExtension, trajectoryExtension, type TrajectoryMetadata, } from '@kagenti/adk'; @@ -32,6 +33,7 @@ import { PLATFORM_FILE_CONTENT_URL_BASE } from './constants'; export const extractCitation = extractUiExtensionData(citationExtension); export const extractTrajectory = extractUiExtensionData(trajectoryExtension); export const extractErrorExtension = extractUiExtensionData(errorExtension); +export const extractStreamingMessage = extractUiExtensionData(streamingExtension); export function convertMessageParts(uiParts: UIMessagePart[]): Part[] { const parts: Part[] = uiParts diff --git a/apps/adk-ui/src/api/adk-client.ts b/apps/adk-ui/src/api/adk-client.ts index 750fc15f..82792178 100644 --- a/apps/adk-ui/src/api/adk-client.ts +++ b/apps/adk-ui/src/api/adk-client.ts @@ -9,8 +9,6 @@ import { ensureToken } from '#app/(auth)/rsc.tsx'; import { runtimeConfig } from '#contexts/App/runtime-config.ts'; import { getBaseUrl } from '#utils/api/getBaseUrl.ts'; -import { getProxyHeaders } from './utils'; - function buildAuthenticatedAdkClient() { const { isAuthEnabled } = runtimeConfig; const baseUrl = getBaseUrl(); @@ -26,19 +24,6 @@ function buildAuthenticatedAdkClient() { } } - const isServer = typeof window === 'undefined'; - - if (isServer) { - const { headers } = await import('next/headers'); - const { forwarded, forwardedHost, forwardedFor, forwardedProto } = await getProxyHeaders(await headers()); - - request.headers.set('forwarded', forwarded); - - if (forwardedHost) request.headers.set('x-forwarded-host', forwardedHost); - if (forwardedProto) request.headers.set('x-forwarded-proto', forwardedProto); - if (forwardedFor) request.headers.set('x-forwarded-for', forwardedFor); - } - const response = await fetch(request); return response; diff --git a/apps/adk-ui/src/app/(auth)/auth.ts b/apps/adk-ui/src/app/(auth)/auth.ts index a464dbf6..b8755bc2 100644 --- a/apps/adk-ui/src/app/(auth)/auth.ts +++ b/apps/adk-ui/src/app/(auth)/auth.ts @@ -60,7 +60,7 @@ export function getProvider(): ProviderWithId | null { const issuer = process.env.OIDC_PROVIDER_ISSUER; const externalIssuer = process.env.OIDC_PROVIDER_EXTERNAL_ISSUER; - if (!name || !id || !clientId || !clientSecret || !issuer) { + if (!name || !id || !clientId || !issuer) { throw new Error( 'Missing OIDC provider configuration. Set OIDC_PROVIDER_NAME, OIDC_PROVIDER_ID, OIDC_PROVIDER_CLIENT_ID, OIDC_PROVIDER_CLIENT_SECRET, and OIDC_PROVIDER_ISSUER.', ); diff --git a/apps/adk-ui/src/modules/platform-context/constants.ts b/apps/adk-ui/src/modules/platform-context/constants.ts index abe8243a..15364f35 100644 --- a/apps/adk-ui/src/modules/platform-context/constants.ts +++ b/apps/adk-ui/src/modules/platform-context/constants.ts @@ -16,7 +16,7 @@ export const contextTokenPermissionsDefaults: DeepRequired) { pendingRun.current = run; let isFirstIteration = true; - pendingSubscription.current = run.subscribe(({ parts, taskId: responseTaskId }) => { + pendingSubscription.current = run.subscribe(({ parts, taskId: responseTaskId, replace }) => { if (isFirstIteration) { queryClient.invalidateQueries({ queryKey: contextKeys.lists() }); } updateCurrentAgentMessage((message) => { message.taskId = responseTaskId; - }); - parts.forEach((part) => { - updateCurrentAgentMessage((message) => { - const updatedParts = addMessagePart(part, message); - message.parts = updatedParts; - }); + if (replace) { + // Streaming update: replace text parts with latest draft + const nonTextParts = message.parts.filter((p) => p.kind !== UIMessagePartKind.Text); + message.parts = [...parts, ...nonTextParts]; + } else { + for (const part of parts) { + message.parts = addMessagePart(part, message); + } + } }); isFirstIteration = false; diff --git a/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/index.ts b/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/index.ts new file mode 100644 index 00000000..d1186faf --- /dev/null +++ b/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/index.ts @@ -0,0 +1,20 @@ +/** + * Copyright 2025 © BeeAI a Series of LF Projects, LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import z from 'zod'; + +import type { A2AUiExtension } from '../../../../core/extensions/types'; +import { streamingMetadataSchema } from './schemas'; +import type { StreamingMetadata } from './types'; + +export type { StreamingMetadata, StreamingPatch } from './types'; +export { streamingMetadataSchema, streamingPatchSchema } from './schemas'; + +export const STREAMING_EXTENSION_URI = 'https://a2a-extensions.agentstack.beeai.dev/ui/streaming/v1'; + +export const streamingExtension: A2AUiExtension = { + getUri: () => STREAMING_EXTENSION_URI, + getMessageMetadataSchema: () => z.object({ [STREAMING_EXTENSION_URI]: streamingMetadataSchema }).partial(), +}; diff --git a/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts b/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts new file mode 100644 index 00000000..3cce4408 --- /dev/null +++ b/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts @@ -0,0 +1,18 @@ +/** + * Copyright 2025 © BeeAI a Series of LF Projects, LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import z from 'zod'; + +export const streamingPatchSchema = z.object({ + op: z.string(), + path: z.string(), + value: z.unknown().optional(), + pos: z.number().optional(), // for str_ins +}); + +export const streamingMetadataSchema = z.object({ + message_update: z.array(streamingPatchSchema).optional(), + message_id: z.string().optional(), +}); diff --git a/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/types.ts b/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/types.ts new file mode 100644 index 00000000..aa10398f --- /dev/null +++ b/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/types.ts @@ -0,0 +1,12 @@ +/** + * Copyright 2025 © BeeAI a Series of LF Projects, LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type z from 'zod'; + +import type { streamingMetadataSchema, streamingPatchSchema } from './schemas'; + +export type StreamingPatch = z.infer; + +export type StreamingMetadata = z.infer; From 9cc29489db93d728460b10e9c8edfef2aa5c9dd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petr=20Bula=CC=81nek?= Date: Fri, 27 Mar 2026 16:43:20 +0100 Subject: [PATCH 2/4] fix(sdk): review fixes for streaming extension PR MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Petr Bulánek --- .../kagenti_adk/a2a/extensions/streaming.py | 2 +- .../a2a/extensions/ui/streaming/index.ts | 4 +- apps/adk-ui/src/api/a2a/client.ts | 6 +- apps/adk-ui/src/api/a2a/streaming.ts | 23 +- apps/adk-ui/src/app/(auth)/auth.ts | 2 +- .../agents/api/queries/useListAgents.ts | 10 +- .../components/import/ImportAgentsModal.tsx | 11 +- .../modules/auth/components/SignInError.tsx | 8 +- .../auth/components/SignInProviders.tsx | 3 +- .../files/api/mutations/useUploadFile.ts | 2 +- .../api/mutations/useCreateContext.ts | 2 +- .../api/mutations/useMatchModelProviders.ts | 2 +- .../api/queries/useContextToken.ts | 2 +- .../api/queries/useListContextHistory.ts | 2 +- .../api/queries/useListContexts.ts | 2 +- .../contexts/platform-context.ts | 2 +- .../api/mutations/useDeleteProvider.ts | 2 +- .../providers/api/queries/useListProviders.ts | 2 +- .../runs/api/queries/useBuildA2AClient.ts | 2 +- .../agent-demands/agent-demands-context.ts | 8 +- .../contexts/agent-run/AgentRunProvider.tsx | 2 +- .../a2a/extensions/ui/streaming/index.ts | 20 -- .../a2a/extensions/ui/streaming/schemas.ts | 18 -- .../a2a/extensions/ui/streaming/types.ts | 12 - pnpm-lock.yaml | 288 ++++++++++++++---- 25 files changed, 267 insertions(+), 170 deletions(-) delete mode 100644 apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/index.ts delete mode 100644 apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts delete mode 100644 apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/types.ts diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/streaming.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/streaming.py index 2467f7a7..3fe3c5d1 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/streaming.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/streaming.py @@ -30,7 +30,7 @@ class StreamingExtensionSpec(NoParamsBaseExtensionSpec[NoneType]): - URI = "https://a2a-extensions.agentstack.beeai.dev/ui/streaming/v1" + URI = "https://a2a-extensions.adk.kagenti.dev/ui/streaming/v1" DESCRIPTION = "Enables fine-grained streaming of token chunks." diff --git a/apps/adk-ts/src/client/a2a/extensions/ui/streaming/index.ts b/apps/adk-ts/src/client/a2a/extensions/ui/streaming/index.ts index d1186faf..04115242 100644 --- a/apps/adk-ts/src/client/a2a/extensions/ui/streaming/index.ts +++ b/apps/adk-ts/src/client/a2a/extensions/ui/streaming/index.ts @@ -9,10 +9,10 @@ import type { A2AUiExtension } from '../../../../core/extensions/types'; import { streamingMetadataSchema } from './schemas'; import type { StreamingMetadata } from './types'; -export type { StreamingMetadata, StreamingPatch } from './types'; export { streamingMetadataSchema, streamingPatchSchema } from './schemas'; +export type { StreamingMetadata, StreamingPatch } from './types'; -export const STREAMING_EXTENSION_URI = 'https://a2a-extensions.agentstack.beeai.dev/ui/streaming/v1'; +export const STREAMING_EXTENSION_URI = 'https://a2a-extensions.adk.kagenti.dev/ui/streaming/v1'; export const streamingExtension: A2AUiExtension = { getUri: () => STREAMING_EXTENSION_URI, diff --git a/apps/adk-ui/src/api/a2a/client.ts b/apps/adk-ui/src/api/a2a/client.ts index 0b3afe09..5328e4ce 100644 --- a/apps/adk-ui/src/api/a2a/client.ts +++ b/apps/adk-ui/src/api/a2a/client.ts @@ -156,14 +156,12 @@ export const buildA2AClient = async ({ taskId = statusUpdate.taskId; // Check for streaming patches in metadata - const patches = extractStreamingPatches( - statusUpdate.metadata as Record | undefined, - ); + const patches = extractStreamingPatches(statusUpdate.metadata); if (patches && taskId) { // Apply patches to draft and emit as a replace update applyPatches(streamingDraft, patches); - const draftParts = (streamingDraft.parts as Array>) ?? []; + const draftParts = Array.isArray(streamingDraft.parts) ? streamingDraft.parts : []; const uiParts: UIMessagePart[] = draftParts .map((part): UIMessagePart | null => { if (typeof part.text === 'string') { diff --git a/apps/adk-ui/src/api/a2a/streaming.ts b/apps/adk-ui/src/api/a2a/streaming.ts index bdf03d1d..9e63db1c 100644 --- a/apps/adk-ui/src/api/a2a/streaming.ts +++ b/apps/adk-ui/src/api/a2a/streaming.ts @@ -9,9 +9,7 @@ import { STREAMING_EXTENSION_URI } from '@kagenti/adk'; /** * Extract streaming patches from status update metadata. */ -export function extractStreamingPatches( - metadata: Record | undefined | null, -): StreamingPatch[] | null { +export function extractStreamingPatches(metadata: Record | undefined | null): StreamingPatch[] | null { if (!metadata) return null; const streamingData = metadata[STREAMING_EXTENSION_URI] as StreamingMetadata | undefined; if (!streamingData?.message_update?.length) return null; @@ -22,7 +20,7 @@ export function extractStreamingPatches( * Resolve a JSON pointer path to get/set a value in a nested object. * Returns [parent, key] for the target location. */ -function resolvePath(obj: Record, path: string): [Record, string] { +function resolvePath(obj: Record, path: string): [unknown, string] { if (path === '' || path === '/') return [obj, '']; const parts = path.split('/').filter(Boolean); @@ -33,11 +31,11 @@ function resolvePath(obj: Record, path: string): [Record)[key]; + current = current[key]; } } - return [current as Record, parts[parts.length - 1]]; + return [current, parts[parts.length - 1]]; } function getByPath(obj: Record, path: string): unknown { @@ -48,7 +46,7 @@ function getByPath(obj: Record, path: string): unknown { if (Array.isArray(current)) { current = current[Number(key)]; } else if (current && typeof current === 'object') { - current = (current as Record)[key]; + current = current[key]; } else { return undefined; } @@ -64,6 +62,10 @@ function cloneValue(v: T): T { /** * Apply a single streaming patch to a draft message object. * Supports: replace, add, str_ins (custom string insertion). + * + * Note: RFC 6902 `remove` and `move` ops are intentionally not implemented. + * The Python accumulator (server side) only emits replace/add/str_ins deltas + * for progressive text construction. If that changes, add them here. */ function applyPatch(draft: Record, patch: StreamingPatch): void { const { op, path, value } = patch; @@ -118,6 +120,8 @@ function applyPatch(draft: Record, patch: StreamingPatch): void } } } + } else { + console.warn(`Unsupported streaming patch op: "${op}". Skipping.`); } } @@ -125,10 +129,7 @@ function applyPatch(draft: Record, patch: StreamingPatch): void * Apply an array of streaming patches to a draft message object. * Mutates the draft in place and returns it. */ -export function applyPatches( - draft: Record, - patches: StreamingPatch[], -): Record { +export function applyPatches(draft: Record, patches: StreamingPatch[]): Record { for (const patch of patches) { applyPatch(draft, patch); } diff --git a/apps/adk-ui/src/app/(auth)/auth.ts b/apps/adk-ui/src/app/(auth)/auth.ts index b8755bc2..12a00bc0 100644 --- a/apps/adk-ui/src/app/(auth)/auth.ts +++ b/apps/adk-ui/src/app/(auth)/auth.ts @@ -62,7 +62,7 @@ export function getProvider(): ProviderWithId | null { if (!name || !id || !clientId || !issuer) { throw new Error( - 'Missing OIDC provider configuration. Set OIDC_PROVIDER_NAME, OIDC_PROVIDER_ID, OIDC_PROVIDER_CLIENT_ID, OIDC_PROVIDER_CLIENT_SECRET, and OIDC_PROVIDER_ISSUER.', + 'Missing OIDC provider configuration. Set OIDC_PROVIDER_NAME, OIDC_PROVIDER_ID, OIDC_PROVIDER_CLIENT_ID, and OIDC_PROVIDER_ISSUER.', ); } diff --git a/apps/adk-ui/src/modules/agents/api/queries/useListAgents.ts b/apps/adk-ui/src/modules/agents/api/queries/useListAgents.ts index 2108db9d..414d027b 100644 --- a/apps/adk-ui/src/modules/agents/api/queries/useListAgents.ts +++ b/apps/adk-ui/src/modules/agents/api/queries/useListAgents.ts @@ -3,12 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { type ListProvidersRequest, type ListProvidersResponse, ProviderState } from '@kagenti/adk'; import { useQuery } from '@tanstack/react-query'; -import { - type ListProvidersRequest, - type ListProvidersResponse, - ProviderState, -} from '@kagenti/adk'; import { buildAgent, isAgentUiSupported, sortAgentsByName, sortProvidersBy } from '#modules/agents/utils.ts'; import { listProviders } from '#modules/providers/api/index.ts'; @@ -35,9 +31,7 @@ export function useListAgents({ includeUnsupportedUi, includeOffline, orderBy, i } if (!includeOffline) { - items = items.filter( - ({ state }) => state !== ProviderState.Offline, - ); + items = items.filter(({ state }) => state !== ProviderState.Offline); } let agents = items.map(buildAgent); diff --git a/apps/adk-ui/src/modules/agents/components/import/ImportAgentsModal.tsx b/apps/adk-ui/src/modules/agents/components/import/ImportAgentsModal.tsx index 31f7b6ea..28a39cad 100644 --- a/apps/adk-ui/src/modules/agents/components/import/ImportAgentsModal.tsx +++ b/apps/adk-ui/src/modules/agents/components/import/ImportAgentsModal.tsx @@ -3,8 +3,15 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { Button, InlineLoading, InlineNotification, ModalBody, ModalFooter, ModalHeader, TextInput } from '@carbon/react'; -import clsx from 'clsx'; +import { + Button, + InlineLoading, + InlineNotification, + ModalBody, + ModalFooter, + ModalHeader, + TextInput, +} from '@carbon/react'; import { useId } from 'react'; import { useForm } from 'react-hook-form'; diff --git a/apps/adk-ui/src/modules/auth/components/SignInError.tsx b/apps/adk-ui/src/modules/auth/components/SignInError.tsx index d46b6f3f..75b247bf 100644 --- a/apps/adk-ui/src/modules/auth/components/SignInError.tsx +++ b/apps/adk-ui/src/modules/auth/components/SignInError.tsx @@ -22,13 +22,7 @@ export function SignInError({ message, callbackUrl = routes.home() }: Props) { return (
- + diff --git a/apps/adk-ui/src/modules/auth/components/SignInProviders.tsx b/apps/adk-ui/src/modules/auth/components/SignInProviders.tsx index b2661fb6..1741c286 100644 --- a/apps/adk-ui/src/modules/auth/components/SignInProviders.tsx +++ b/apps/adk-ui/src/modules/auth/components/SignInProviders.tsx @@ -54,10 +54,9 @@ async function handleSignIn( } } - const AUTH_ERROR_MESSAGES: Record = { Configuration: 'Unable to connect to the identity provider. Please verify that the authentication service is running and correctly configured.', IdentityProviderUnavailable: 'Unable to connect to the identity provider. Please verify that the authentication service is running and try again.', -}; \ No newline at end of file +}; diff --git a/apps/adk-ui/src/modules/files/api/mutations/useUploadFile.ts b/apps/adk-ui/src/modules/files/api/mutations/useUploadFile.ts index 4495735d..dd79d25f 100644 --- a/apps/adk-ui/src/modules/files/api/mutations/useUploadFile.ts +++ b/apps/adk-ui/src/modules/files/api/mutations/useUploadFile.ts @@ -3,8 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { useMutation } from '@tanstack/react-query'; import type { CreateFileResponse } from '@kagenti/adk'; +import { useMutation } from '@tanstack/react-query'; import { uploadFile } from '..'; import type { UploadFileParams } from '../types'; diff --git a/apps/adk-ui/src/modules/platform-context/api/mutations/useCreateContext.ts b/apps/adk-ui/src/modules/platform-context/api/mutations/useCreateContext.ts index 77cbb3af..80c68bcb 100644 --- a/apps/adk-ui/src/modules/platform-context/api/mutations/useCreateContext.ts +++ b/apps/adk-ui/src/modules/platform-context/api/mutations/useCreateContext.ts @@ -3,8 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { useMutation } from '@tanstack/react-query'; import type { Context } from '@kagenti/adk'; +import { useMutation } from '@tanstack/react-query'; import { createContext } from '..'; import { contextKeys } from '../keys'; diff --git a/apps/adk-ui/src/modules/platform-context/api/mutations/useMatchModelProviders.ts b/apps/adk-ui/src/modules/platform-context/api/mutations/useMatchModelProviders.ts index 57729a36..c353ecaa 100644 --- a/apps/adk-ui/src/modules/platform-context/api/mutations/useMatchModelProviders.ts +++ b/apps/adk-ui/src/modules/platform-context/api/mutations/useMatchModelProviders.ts @@ -3,9 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { useQuery } from '@tanstack/react-query'; import type { EmbeddingDemands, LLMDemands } from '@kagenti/adk'; import { ModelCapability } from '@kagenti/adk'; +import { useQuery } from '@tanstack/react-query'; import { useEffect } from 'react'; import { useApp } from '#contexts/App/index.ts'; diff --git a/apps/adk-ui/src/modules/platform-context/api/queries/useContextToken.ts b/apps/adk-ui/src/modules/platform-context/api/queries/useContextToken.ts index 4c5cbb85..e508c286 100644 --- a/apps/adk-ui/src/modules/platform-context/api/queries/useContextToken.ts +++ b/apps/adk-ui/src/modules/platform-context/api/queries/useContextToken.ts @@ -3,8 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { useQuery } from '@tanstack/react-query'; import type { ContextToken } from '@kagenti/adk'; +import { useQuery } from '@tanstack/react-query'; import { useApp } from '#contexts/App/index.ts'; import type { Agent } from '#modules/agents/api/types.ts'; diff --git a/apps/adk-ui/src/modules/platform-context/api/queries/useListContextHistory.ts b/apps/adk-ui/src/modules/platform-context/api/queries/useListContextHistory.ts index 61e4ac3e..395c0ea1 100644 --- a/apps/adk-ui/src/modules/platform-context/api/queries/useListContextHistory.ts +++ b/apps/adk-ui/src/modules/platform-context/api/queries/useListContextHistory.ts @@ -3,8 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { useInfiniteQuery } from '@tanstack/react-query'; import type { ListContextHistoryRequest, ListContextHistoryResponse } from '@kagenti/adk'; +import { useInfiniteQuery } from '@tanstack/react-query'; import type { PartialBy } from '#@types/utils.ts'; import { isNotNull } from '#utils/helpers.ts'; diff --git a/apps/adk-ui/src/modules/platform-context/api/queries/useListContexts.ts b/apps/adk-ui/src/modules/platform-context/api/queries/useListContexts.ts index 0d4d7187..42db3ae3 100644 --- a/apps/adk-ui/src/modules/platform-context/api/queries/useListContexts.ts +++ b/apps/adk-ui/src/modules/platform-context/api/queries/useListContexts.ts @@ -3,8 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { useInfiniteQuery } from '@tanstack/react-query'; import type { ListContextsRequest } from '@kagenti/adk'; +import { useInfiniteQuery } from '@tanstack/react-query'; import { isNotNull } from '#utils/helpers.ts'; diff --git a/apps/adk-ui/src/modules/platform-context/contexts/platform-context.ts b/apps/adk-ui/src/modules/platform-context/contexts/platform-context.ts index 8536c393..14a2510b 100644 --- a/apps/adk-ui/src/modules/platform-context/contexts/platform-context.ts +++ b/apps/adk-ui/src/modules/platform-context/contexts/platform-context.ts @@ -3,8 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { UseMutateAsyncFunction } from '@tanstack/react-query'; import type { Context, CreateContextRequest, ListContextHistoryResponse } from '@kagenti/adk'; +import type { UseMutateAsyncFunction } from '@tanstack/react-query'; import { createContext } from 'react'; import type { Agent } from '#modules/agents/api/types.ts'; diff --git a/apps/adk-ui/src/modules/providers/api/mutations/useDeleteProvider.ts b/apps/adk-ui/src/modules/providers/api/mutations/useDeleteProvider.ts index 05b400ea..f8b9ab7e 100644 --- a/apps/adk-ui/src/modules/providers/api/mutations/useDeleteProvider.ts +++ b/apps/adk-ui/src/modules/providers/api/mutations/useDeleteProvider.ts @@ -3,8 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { useMutation, useQueryClient } from '@tanstack/react-query'; import type { ListProvidersResponse } from '@kagenti/adk'; +import { useMutation, useQueryClient } from '@tanstack/react-query'; import { providerKeys } from '#modules/providers/api/keys.ts'; diff --git a/apps/adk-ui/src/modules/providers/api/queries/useListProviders.ts b/apps/adk-ui/src/modules/providers/api/queries/useListProviders.ts index 4c863b54..e208def4 100644 --- a/apps/adk-ui/src/modules/providers/api/queries/useListProviders.ts +++ b/apps/adk-ui/src/modules/providers/api/queries/useListProviders.ts @@ -3,8 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { useQuery } from '@tanstack/react-query'; import type { ListProvidersRequest } from '@kagenti/adk'; +import { useQuery } from '@tanstack/react-query'; import { listProviders } from '..'; import { providerKeys } from '../keys'; diff --git a/apps/adk-ui/src/modules/runs/api/queries/useBuildA2AClient.ts b/apps/adk-ui/src/modules/runs/api/queries/useBuildA2AClient.ts index 09eae428..44714774 100644 --- a/apps/adk-ui/src/modules/runs/api/queries/useBuildA2AClient.ts +++ b/apps/adk-ui/src/modules/runs/api/queries/useBuildA2AClient.ts @@ -3,8 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { useQuery } from '@tanstack/react-query'; import type { TaskStatusUpdateEvent } from '@kagenti/adk'; +import { useQuery } from '@tanstack/react-query'; import { buildA2AClient } from '#api/a2a/client.ts'; diff --git a/apps/adk-ui/src/modules/runs/contexts/agent-demands/agent-demands-context.ts b/apps/adk-ui/src/modules/runs/contexts/agent-demands/agent-demands-context.ts index 46a977bd..14e663ab 100644 --- a/apps/adk-ui/src/modules/runs/contexts/agent-demands/agent-demands-context.ts +++ b/apps/adk-ui/src/modules/runs/contexts/agent-demands/agent-demands-context.ts @@ -3,13 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { - ApprovalDecision, - FormDemands, - Fulfillments, - SettingsFormRender, - SettingsFormValues, -} from '@kagenti/adk'; +import type { ApprovalDecision, FormDemands, Fulfillments, SettingsFormRender, SettingsFormValues } from '@kagenti/adk'; import { createContext } from 'react'; import type { RunFormValues } from '#modules/form/types.ts'; diff --git a/apps/adk-ui/src/modules/runs/contexts/agent-run/AgentRunProvider.tsx b/apps/adk-ui/src/modules/runs/contexts/agent-run/AgentRunProvider.tsx index 5bed7bcb..00dc2593 100644 --- a/apps/adk-ui/src/modules/runs/contexts/agent-run/AgentRunProvider.tsx +++ b/apps/adk-ui/src/modules/runs/contexts/agent-run/AgentRunProvider.tsx @@ -4,9 +4,9 @@ */ 'use client'; -import { useQueryClient } from '@tanstack/react-query'; import type { ApprovalDecision } from '@kagenti/adk'; import { TaskStatusUpdateType } from '@kagenti/adk'; +import { useQueryClient } from '@tanstack/react-query'; import type { PropsWithChildren } from 'react'; import { useCallback, useMemo, useRef, useState } from 'react'; import { match } from 'ts-pattern'; diff --git a/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/index.ts b/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/index.ts deleted file mode 100644 index d1186faf..00000000 --- a/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/index.ts +++ /dev/null @@ -1,20 +0,0 @@ -/** - * Copyright 2025 © BeeAI a Series of LF Projects, LLC - * SPDX-License-Identifier: Apache-2.0 - */ - -import z from 'zod'; - -import type { A2AUiExtension } from '../../../../core/extensions/types'; -import { streamingMetadataSchema } from './schemas'; -import type { StreamingMetadata } from './types'; - -export type { StreamingMetadata, StreamingPatch } from './types'; -export { streamingMetadataSchema, streamingPatchSchema } from './schemas'; - -export const STREAMING_EXTENSION_URI = 'https://a2a-extensions.agentstack.beeai.dev/ui/streaming/v1'; - -export const streamingExtension: A2AUiExtension = { - getUri: () => STREAMING_EXTENSION_URI, - getMessageMetadataSchema: () => z.object({ [STREAMING_EXTENSION_URI]: streamingMetadataSchema }).partial(), -}; diff --git a/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts b/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts deleted file mode 100644 index 3cce4408..00000000 --- a/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts +++ /dev/null @@ -1,18 +0,0 @@ -/** - * Copyright 2025 © BeeAI a Series of LF Projects, LLC - * SPDX-License-Identifier: Apache-2.0 - */ - -import z from 'zod'; - -export const streamingPatchSchema = z.object({ - op: z.string(), - path: z.string(), - value: z.unknown().optional(), - pos: z.number().optional(), // for str_ins -}); - -export const streamingMetadataSchema = z.object({ - message_update: z.array(streamingPatchSchema).optional(), - message_id: z.string().optional(), -}); diff --git a/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/types.ts b/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/types.ts deleted file mode 100644 index aa10398f..00000000 --- a/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/types.ts +++ /dev/null @@ -1,12 +0,0 @@ -/** - * Copyright 2025 © BeeAI a Series of LF Projects, LLC - * SPDX-License-Identifier: Apache-2.0 - */ - -import type z from 'zod'; - -import type { streamingMetadataSchema, streamingPatchSchema } from './schemas'; - -export type StreamingPatch = z.infer; - -export type StreamingMetadata = z.infer; diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 6c0f339b..858f0e9b 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -70,25 +70,6 @@ catalogs: specifier: ^3.1.1 version: 3.1.1 -overrides: - lodash@<4.17.23: 4.17.23 - dompurify@<3.3.2: 3.3.2 - express@>=4.0.0 <4.20.0: 4.21.2 - zod@<3.22.3: 3.22.3 - js-yaml@>=4.0.0 <4.1.1: 4.1.1 - qs@<6.14.2: 6.14.2 - send@>=0.0.0 <0.19.0: 0.19.0 - serve-static@>=1.0.0 <1.16.0: 1.16.0 - cookie@<0.7.0: 0.7.0 - path-to-regexp@<0.1.12: 0.1.12 - immutable@>=5.0.0 <5.1.5: 5.1.5 - minimatch@<3.1.5: 3.1.5 - axios@<1.13.5: 1.13.5 - svgo@>=3.0.0 <3.3.3: 3.3.3 - body-parser@<1.20.3: 1.20.3 - tar@<7.5.10: 7.5.10 - '@mintlify/previewing>tar': 6.2.1 - importers: apps/adk-ts: @@ -97,7 +78,7 @@ importers: specifier: ^0.3.10 version: 0.3.10(@bufbuild/protobuf@2.11.0)(express@4.21.2) express: - specifier: 4.21.2 + specifier: ^4.18.0 || ^5.0.0 version: 4.21.2 zod: specifier: ^4.3.6 @@ -3929,8 +3910,11 @@ packages: resolution: {integrity: sha512-BASOg+YwO2C+346x3LZOeoovTIoTrRqEsqMa6fmfAV0P+U9mFr9NsyOEpiYvFjbc64NMrSswhV50WdXzdb/Z5A==} engines: {node: '>=4'} - axios@1.13.5: - resolution: {integrity: sha512-cz4ur7Vb0xS4/KUN0tPWe44eqxrIu31me+fbang3ijiNscE129POzipJJA6zniq2C/Z6sJCjMimjS8Lc/GAs8Q==} + axios@1.10.0: + resolution: {integrity: sha512-/1xYAC4MP/HEG+3duIhFr4ZQXR4sQXOIe+o6sdqzeykGLx6Upp/1p8MHqhINOvGeP7xyNHe7tsiJByc4SSVUxw==} + + axios@1.13.2: + resolution: {integrity: sha512-VPk9ebNqPcy5lRGuSlKx752IlDatOjT9paPlm8A7yOuW2Fbvp4X3JznJtT4f0GzGLLiWE9W8onz51SqLYwzGaA==} axobject-query@4.1.0: resolution: {integrity: sha512-qIj0G9wZbMGNLjLmg1PT6v2mE9AH2zlnADJD/2tC6E00hgmhUOfEB6greHPAfLRSufHqROIUTkw6E+M3lH0PTQ==} @@ -4034,6 +4018,10 @@ packages: resolution: {integrity: sha512-Ceh+7ox5qe7LJuLHoY0feh3pHuUDHAcRUeyL2VYghZwfpkNIy/+8Ocg0a3UuSoYzavmylwuLWQOf3hl0jjMMIw==} engines: {node: '>=8'} + body-parser@1.20.1: + resolution: {integrity: sha512-jWi7abTbYwajOytWCQc37VulmWiRae5RyTpaCyDcS5/lMdtwSz5lOpDE67srw/HYe35f1z3fDQw+3txg7gNtWw==} + engines: {node: '>= 0.8', npm: 1.2.8000 || >= 1.4.16} + body-parser@1.20.3: resolution: {integrity: sha512-7rAxByjUMqQ3/bHJy7D6OGXvx/MMc4IqBn/X0fcM1QUcAItpZrBEYhWGem+tzXH90c+G01ypMcYJBO9Y30203g==} engines: {node: '>= 0.8', npm: 1.2.8000 || >= 1.4.16} @@ -4324,8 +4312,12 @@ packages: cookie-signature@1.0.6: resolution: {integrity: sha512-QADzlaHc8icV8I7vbaJXJwod9HWYp8uCqf1xa4OfNu1T7JVxQIrUgOWtHdNDtPiywmFbiS12VjotIXLrKM3orQ==} - cookie@0.7.0: - resolution: {integrity: sha512-qCf+V4dtlNhSRXGAZatc1TasyFO6GjohcOul807YOb5ik3+kQSnb4d7iajeCL8QHaJ4uZEjCgiCJerKXwdRVlQ==} + cookie@0.4.2: + resolution: {integrity: sha512-aSWTXFzaKWkvHO1Ny/s+ePFpvKsPnjc551iI41v3ny/ow6tBG5Vd+FuqGNhh1LxOmVzOlGUriIlOaokOvhaStA==} + engines: {node: '>= 0.6'} + + cookie@0.5.0: + resolution: {integrity: sha512-YZ3GUyn/o8gfKJlnlX7g7xq4gyO6OSuhGPKaaGssGB2qgDUS0gPgtTvoyZLTt9Ab6dC4hfc9dV5arkvc/OCmrw==} engines: {node: '>= 0.6'} cookie@0.7.1: @@ -5081,6 +5073,10 @@ packages: resolution: {integrity: sha512-knvyeauYhqjOYvQ66MznSMs83wmHrCycNEN6Ao+2AeYEfxUIkuiVxdEa1qlGEPK+We3n0THiDciYSsCcgW/DoA==} engines: {node: '>=12.0.0'} + express@4.18.2: + resolution: {integrity: sha512-5/PsL6iGPdfQ/lKM1UuielYgv3BUoJfz1aUwU9vHZ+J7gyvwdQXFEBIEIaxeGf0GIcreATNyBExtalisDbuMqQ==} + engines: {node: '>= 0.10.0'} + express@4.21.2: resolution: {integrity: sha512-28HqgMZAmih1Czt9ny7qr6ek2qddF4FclbMzwhCREB6OFfH+rXAnuNCwo1/wFvrtbgsQDb4kSbX9de9lFbrXnA==} engines: {node: '>= 0.10.0'} @@ -5163,6 +5159,10 @@ packages: resolution: {integrity: sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==} engines: {node: '>=8'} + finalhandler@1.2.0: + resolution: {integrity: sha512-5uXcUVftlQMFnWC9qu/svkWv3GTd2PfUhK/3PLkYNAe7FbqJMt3515HaxE6eRL74GdsriiwujiawdaB1BpEISg==} + engines: {node: '>= 0.8'} + finalhandler@1.3.1: resolution: {integrity: sha512-6BN9trH7bp3qvnrRyzsBz+g3lZxTNZTbVO2EV1CS0WIcDbawYVdYvGflME/9QP0h0pYlCDBCTjYa9nZzMDpyxQ==} engines: {node: '>= 0.8'} @@ -5856,6 +5856,10 @@ packages: resolution: {integrity: sha512-PMSmkqxr106Xa156c2M265Z+FTrPl+oxd/rgOQy2tijQeK5TxQ43psO1ZCwhVOSdnn+RzkzlRz/eY4BgJBYVpg==} hasBin: true + js-yaml@4.1.0: + resolution: {integrity: sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==} + hasBin: true + js-yaml@4.1.1: resolution: {integrity: sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==} hasBin: true @@ -6001,6 +6005,9 @@ packages: lodash.truncate@4.4.2: resolution: {integrity: sha512-jttmRe7bRse52OsWIMDLaXxWqRAmtIUccAQ3garviCqJjafXOfNMO0yMfNpdD6zbGaTU0P5Nz7e7gAT6cKmJRw==} + lodash@4.17.21: + resolution: {integrity: sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==} + lodash@4.17.23: resolution: {integrity: sha512-LgVTMpQtIopCi79SJeDiP0TfWi5CNEc/L/aRdTh3yIvmZXTnheWpKjSZhnvMl8iXbC1tFg9gdHHDMLoV7CnG+w==} @@ -6133,6 +6140,9 @@ packages: resolution: {integrity: sha512-EDYo6VlmtnumlcBCbh1gLJ//9jvM/ndXHfVXIFrZVr6fGcwTUyCTFNTLCKuY3ffbK8L/+3Mzqnd58RojiZqHVw==} engines: {node: '>=20'} + merge-descriptors@1.0.1: + resolution: {integrity: sha512-cCi6g3/Zr1iqQi6ySbseM1Xvooa98N0w31jzUYrXPX2xqObmFGHJ0tQ5u74H3mVh7wLouTseZyYIq39g8cNp1w==} + merge-descriptors@1.0.3: resolution: {integrity: sha512-gaNvAS7TZ897/rVaZ0nMtAyxNyi/pdbjbAwUpFQpN70GqnVfOiXpeUUMKRBmzXaSQ8DdTX4/0ms62r2K+hE6mQ==} @@ -6291,6 +6301,9 @@ packages: resolution: {integrity: sha512-oRjTw/97aTBN0RHbYCdtF1MQfvusSIBQM0IZEgzl6426+8jSC0nF1a/GmnVLpfB9yyr6g6FTqWqiZVbxrtaCIg==} engines: {node: 18 || 20 || >=22} + minimatch@3.1.2: + resolution: {integrity: sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==} + minimatch@3.1.5: resolution: {integrity: sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==} @@ -6646,6 +6659,9 @@ packages: path-to-regexp@0.1.12: resolution: {integrity: sha512-RA1GjUVMnvYFxuqovrEqZoxxW5NUZqbwKtYz/Tt7nXerk0LbLblQmrsgdeOxV5SFHf0UDggjS/bSeOZwt1pmEQ==} + path-to-regexp@0.1.7: + resolution: {integrity: sha512-5DFkuoqlv1uYQKxy8omFBeJPQcdoE07Kv2sferDCrAq1ohOU+MSDswDIbnx3YAM60qIOnYa53wBhXW0EbMonrQ==} + path-type@4.0.0: resolution: {integrity: sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw==} engines: {node: '>=8'} @@ -6858,8 +6874,12 @@ packages: resolution: {integrity: sha512-tsSGN1x3h569ZSU1u6diwhltLyfUWDp3YbFHedapTmpBl0B3P6U3+Qptg7xu+v+1io1EwhdPyyRHYbEw0KN2FA==} engines: {node: '>=20'} - qs@6.14.2: - resolution: {integrity: sha512-V/yCWTTF7VJ9hIh18Ugr2zhJMP01MY7c5kh4J870L7imm6/DIzBsNLTXzMwUA3yZ5b/KBqLx8Kp3uRvd7xSe3Q==} + qs@6.11.0: + resolution: {integrity: sha512-MvjoMCJwEarSbUYk5O+nmoSzSutSsTwF85zcHPQ9OrlFoZOYIjaqBAJIqIXjptyD5vThxGq52Xu/MaJzRkIk4Q==} + engines: {node: '>=0.6'} + + qs@6.13.0: + resolution: {integrity: sha512-+38qI9SOr8tfZ4QmJNplMUxqjbe7LKvvZgWdExBOmd+egZTtjLB67Gu0HRX3u/XOq7UU2Nx6nsjvS16Z9uwfpg==} engines: {node: '>=0.6'} queue-microtask@1.2.3: @@ -6873,6 +6893,10 @@ packages: resolution: {integrity: sha512-Hrgsx+orqoygnmhFbKaHE6c296J+HTAQXoxEF6gNupROmmGJRoyzfG3ccAveqCBrwr/2yxQ5BVd/GTl5agOwSg==} engines: {node: '>= 0.6'} + raw-body@2.5.1: + resolution: {integrity: sha512-qqJBtEyVgS0ZmPGdCFPWJ3FreoqvG4MVQln/kCgF7Olq95IbOp0/BWyMwbdtn4VTvkM8Y7khCQ2Xgk/tcrCXig==} + engines: {node: '>= 0.8'} + raw-body@2.5.2: resolution: {integrity: sha512-8zGqypfENjCIqGhgXToC8aB2r7YrBX+AQAfIPs/Mlk+BtPTztOvTS01NRW/3Eh60J+a48lt8qsCzirQ6loCVfA==} engines: {node: '>= 0.8'} @@ -7364,6 +7388,10 @@ packages: engines: {node: '>=10'} hasBin: true + send@0.18.0: + resolution: {integrity: sha512-qqWzuOjSFOuqPjFe4NOsMLafToQQwBSOEpS+FwEt3A2V3vKubTquT3vmLTQpFgMXp8AlFWFuP1qKaJZOtPpVXg==} + engines: {node: '>= 0.8.0'} + send@0.19.0: resolution: {integrity: sha512-dW41u5VfLXu8SJh5bwRmyYUbAoSB3c9uQh6L8h/KtsFREPWpbX1lrljJo186Jc4nmci/sGUZ9a0a0J2zgfq2hw==} engines: {node: '>= 0.8.0'} @@ -7372,6 +7400,10 @@ packages: resolution: {integrity: sha512-bBZaRwLH9PN5HbLCjPId4dP5bNGEtumcErgOX952IsvOhVPrm3/AeK1y0UHA/QaPG701eg0yEnOKsCOC6X/kaA==} engines: {node: '>=20'} + serve-static@1.15.0: + resolution: {integrity: sha512-XGuRDNjXUijsUL0vl6nSD7cwURuzEgglbOaFuZM9g3kwDXOWVTck0jLzjPzGD+TazWbboZYu52/9/XPdUgne9g==} + engines: {node: '>= 0.8.0'} + serve-static@1.16.2: resolution: {integrity: sha512-VqpjJZKadQB/PEbEwvFdO43Ax5dFBZ2UECszz8bQ7pi7wt//PWe1P6MN7eCnjsatYtBT6EuiClbjSWP2WrIoTw==} engines: {node: '>= 0.8.0'} @@ -7719,9 +7751,10 @@ packages: tar-stream@3.1.8: resolution: {integrity: sha512-U6QpVRyCGHva435KoNWy9PRoi2IFYCgtEhq9nmrPPpbRacPs9IH4aJ3gbrFC8dPcXvdSZ4XXfXT5Fshbp2MtlQ==} - tar@6.2.1: - resolution: {integrity: sha512-DZ4yORTwrbTj/7MZYq2w+/ZFdI6OZ/f9SFHR+71gIVUZhOQPHzVCLpvRnPgyaMpfWxxk/4ONva3GQSyNIKRv6A==} + tar@6.1.15: + resolution: {integrity: sha512-/zKt9UyngnxIT/EAGYuxaMYgOIJiP81ab9ZfkILq4oNLPFX50qyYmu7jRj9qeXoxmJHjGlbH0+cm2uy1WCs10A==} engines: {node: '>=10'} + deprecated: Old versions of tar are not supported, and contain widely publicized security vulnerabilities, which have been fixed in the current version. Please update. Support for old versions may be purchased (at exorbitant rates) by contacting i@izs.me teex@1.0.1: resolution: {integrity: sha512-eYE6iEI62Ni1H8oIa7KlDU6uQBtqr4Eajni3wX7rpfXD8ysFx8z0+dri+KWEPWpBsxXfxu58x/0jvTVT1ekOSg==} @@ -8366,10 +8399,10 @@ packages: zod-to-json-schema@3.20.4: resolution: {integrity: sha512-Un9+kInJ2Zt63n6Z7mLqBifzzPcOyX+b+Exuzf7L1+xqck9Q2EPByyTRduV3kmSPaXaRer1JCsucubpgL1fipg==} peerDependencies: - zod: 3.22.3 + zod: ^3.20.0 - zod@3.22.3: - resolution: {integrity: sha512-EjIevzuJRiRPbVH4mGc8nApb/lVLKVpmUhAaR5R5doKGfAnGJ6Gr3CViAVjP+4FWSxCsybeWQdcgCtbX+7oZug==} + zod@3.21.4: + resolution: {integrity: sha512-m46AKbrzKVzOzs/DZgVnG5H55N1sv1M8qZU3A8RIKbs3mrACDNeIOeilDymVb2HdmP8uwshOCF4uJ8uM9rCqJw==} zod@3.23.8: resolution: {integrity: sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g==} @@ -10053,7 +10086,7 @@ snapshots: fs-extra: 11.2.0 ink: 6.3.0(@types/react@19.2.14)(react@19.2.3) inquirer: 12.3.0(@types/node@25.4.0) - js-yaml: 4.1.1 + js-yaml: 4.1.0 mdast-util-mdx-jsx: 3.2.0 react: 19.2.3 semver: 7.7.2 @@ -10096,8 +10129,8 @@ snapshots: hast-util-to-text: 4.0.2 hex-rgb: 5.0.0 ignore: 7.0.5 - js-yaml: 4.1.1 - lodash: 4.17.23 + js-yaml: 4.1.0 + lodash: 4.17.21 mdast-util-from-markdown: 2.0.2 mdast-util-gfm: 3.0.0 mdast-util-mdx: 3.0.0 @@ -10157,8 +10190,8 @@ snapshots: hast-util-to-text: 4.0.2 hex-rgb: 5.0.0 ignore: 7.0.5 - js-yaml: 4.1.1 - lodash: 4.17.23 + js-yaml: 4.1.0 + lodash: 4.17.21 mdast-util-from-markdown: 2.0.2 mdast-util-gfm: 3.0.0 mdast-util-mdx: 3.0.0 @@ -10253,14 +10286,14 @@ snapshots: '@mintlify/models@0.0.255': dependencies: - axios: 1.13.5 + axios: 1.10.0 openapi-types: 12.1.3 transitivePeerDependencies: - debug '@mintlify/models@0.0.283': dependencies: - axios: 1.13.5 + axios: 1.13.2 openapi-types: 12.1.3 transitivePeerDependencies: - debug @@ -10284,7 +10317,7 @@ snapshots: favicons: 7.2.0 front-matter: 4.0.2 fs-extra: 11.1.0 - js-yaml: 4.1.1 + js-yaml: 4.1.0 openapi-types: 12.1.3 sharp: 0.33.5 sharp-ico: 0.1.5 @@ -10314,18 +10347,18 @@ snapshots: better-opn: 3.0.2 chalk: 5.2.0 chokidar: 3.5.3 - express: 4.21.2 + express: 4.18.2 front-matter: 4.0.2 fs-extra: 11.1.0 got: 13.0.0 ink: 6.3.0(@types/react@19.2.14)(react@19.2.3) ink-spinner: 5.0.0(ink@6.3.0(@types/react@19.2.14)(react@19.2.3))(react@19.2.3) is-online: 10.0.0 - js-yaml: 4.1.1 + js-yaml: 4.1.0 openapi-types: 12.1.3 react: 19.2.3 socket.io: 4.7.2 - tar: 6.2.1 + tar: 6.1.15 unist-util-visit: 4.1.2 yargs: 17.7.1 transitivePeerDependencies: @@ -10350,7 +10383,7 @@ snapshots: '@mintlify/openapi-parser': 0.0.8 fs-extra: 11.1.1 hast-util-to-mdast: 10.1.0 - js-yaml: 4.1.1 + js-yaml: 4.1.0 mdast-util-mdx-jsx: 3.1.3 neotraverse: 0.6.18 puppeteer: 22.14.0(typescript@5.9.3) @@ -10362,7 +10395,7 @@ snapshots: unified: 11.0.5 unist-util-visit: 5.0.0 yargs: 17.7.1 - zod: 3.22.3 + zod: 3.21.4 transitivePeerDependencies: - '@radix-ui/react-popover' - '@types/react' @@ -10385,7 +10418,7 @@ snapshots: '@mintlify/openapi-parser': 0.0.8 fs-extra: 11.1.1 hast-util-to-mdast: 10.1.0 - js-yaml: 4.1.1 + js-yaml: 4.1.0 mdast-util-mdx-jsx: 3.1.3 neotraverse: 0.6.18 puppeteer: 22.14.0(typescript@5.9.3) @@ -10419,14 +10452,14 @@ snapshots: '@mintlify/mdx': 3.0.4(@radix-ui/react-popover@1.1.15(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.3))(react@19.2.3))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.3))(react@19.2.3)(typescript@5.9.3) '@mintlify/models': 0.0.255 arktype: 2.1.27 - js-yaml: 4.1.1 + js-yaml: 4.1.0 lcm: 0.0.3 - lodash: 4.17.23 + lodash: 4.17.21 object-hash: 3.0.0 openapi-types: 12.1.3 uuid: 11.1.0 - zod: 3.22.3 - zod-to-json-schema: 3.20.4(zod@3.22.3) + zod: 3.21.4 + zod-to-json-schema: 3.20.4(zod@3.21.4) transitivePeerDependencies: - '@radix-ui/react-popover' - '@types/react' @@ -10441,9 +10474,9 @@ snapshots: '@mintlify/mdx': 3.0.4(@radix-ui/react-popover@1.1.15(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.3))(react@19.2.3))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.3))(react@19.2.3)(typescript@5.9.3) '@mintlify/models': 0.0.283 arktype: 2.1.27 - js-yaml: 4.1.1 + js-yaml: 4.1.0 lcm: 0.0.3 - lodash: 4.17.23 + lodash: 4.17.21 object-hash: 3.0.0 openapi-types: 12.1.3 uuid: 11.1.0 @@ -11987,7 +12020,7 @@ snapshots: jsonpath-plus: 10.4.0 lodash: 4.17.23 lodash.topath: 4.5.2 - minimatch: 3.1.5 + minimatch: 3.1.2 nimma: 0.2.3 pony-cause: 1.1.1 simple-eval: 1.0.1 @@ -12875,7 +12908,15 @@ snapshots: axe-core@4.11.1: {} - axios@1.13.5: + axios@1.10.0: + dependencies: + follow-redirects: 1.15.11 + form-data: 4.0.5 + proxy-from-env: 1.1.0 + transitivePeerDependencies: + - debug + + axios@1.13.2: dependencies: follow-redirects: 1.15.11 form-data: 4.0.5 @@ -12968,6 +13009,23 @@ snapshots: binary-extensions@2.3.0: {} + body-parser@1.20.1: + dependencies: + bytes: 3.1.2 + content-type: 1.0.5 + debug: 2.6.9 + depd: 2.0.0 + destroy: 1.2.0 + http-errors: 2.0.0 + iconv-lite: 0.4.24 + on-finished: 2.4.1 + qs: 6.11.0 + raw-body: 2.5.1 + type-is: 1.6.18 + unpipe: 1.0.0 + transitivePeerDependencies: + - supports-color + body-parser@1.20.3: dependencies: bytes: 3.1.2 @@ -12978,7 +13036,7 @@ snapshots: http-errors: 2.0.0 iconv-lite: 0.4.24 on-finished: 2.4.1 - qs: 6.14.2 + qs: 6.13.0 raw-body: 2.5.2 type-is: 1.6.18 unpipe: 1.0.0 @@ -13256,7 +13314,9 @@ snapshots: cookie-signature@1.0.6: {} - cookie@0.7.0: {} + cookie@0.4.2: {} + + cookie@0.5.0: {} cookie@0.7.1: {} @@ -13749,7 +13809,7 @@ snapshots: '@types/node': 25.4.0 accepts: 1.3.8 base64id: 2.0.0 - cookie: 0.7.0 + cookie: 0.4.2 cors: 2.8.6 debug: 4.3.7 engine.io-parser: 5.2.3 @@ -13995,7 +14055,7 @@ snapshots: transitivePeerDependencies: - supports-color - eslint-module-utils@2.12.1(@typescript-eslint/parser@8.56.1(eslint@9.39.3(jiti@1.21.7))(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@9.39.3(jiti@1.21.7)): + eslint-module-utils@2.12.1(@typescript-eslint/parser@8.56.1(eslint@9.39.3(jiti@1.21.7))(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0)(eslint@9.39.3(jiti@1.21.7)))(eslint@9.39.3(jiti@1.21.7)): dependencies: debug: 3.2.7 optionalDependencies: @@ -14023,7 +14083,7 @@ snapshots: doctrine: 2.1.0 eslint: 9.39.3(jiti@1.21.7) eslint-import-resolver-node: 0.3.9 - eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.56.1(eslint@9.39.3(jiti@1.21.7))(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@9.39.3(jiti@1.21.7)) + eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.56.1(eslint@9.39.3(jiti@1.21.7))(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0)(eslint@9.39.3(jiti@1.21.7)))(eslint@9.39.3(jiti@1.21.7)) hasown: 2.0.2 is-core-module: 2.16.1 is-glob: 4.0.3 @@ -14209,6 +14269,42 @@ snapshots: expect-type@1.3.0: {} + express@4.18.2: + dependencies: + accepts: 1.3.8 + array-flatten: 1.1.1 + body-parser: 1.20.1 + content-disposition: 0.5.4 + content-type: 1.0.5 + cookie: 0.5.0 + cookie-signature: 1.0.6 + debug: 2.6.9 + depd: 2.0.0 + encodeurl: 1.0.2 + escape-html: 1.0.3 + etag: 1.8.1 + finalhandler: 1.2.0 + fresh: 0.5.2 + http-errors: 2.0.0 + merge-descriptors: 1.0.1 + methods: 1.1.2 + on-finished: 2.4.1 + parseurl: 1.3.3 + path-to-regexp: 0.1.7 + proxy-addr: 2.0.7 + qs: 6.11.0 + range-parser: 1.2.1 + safe-buffer: 5.2.1 + send: 0.18.0 + serve-static: 1.15.0 + setprototypeof: 1.2.0 + statuses: 2.0.1 + type-is: 1.6.18 + utils-merge: 1.0.1 + vary: 1.1.2 + transitivePeerDependencies: + - supports-color + express@4.21.2: dependencies: accepts: 1.3.8 @@ -14232,7 +14328,7 @@ snapshots: parseurl: 1.3.3 path-to-regexp: 0.1.12 proxy-addr: 2.0.7 - qs: 6.14.2 + qs: 6.13.0 range-parser: 1.2.1 safe-buffer: 5.2.1 send: 0.19.0 @@ -14329,6 +14425,18 @@ snapshots: dependencies: to-regex-range: 5.0.1 + finalhandler@1.2.0: + dependencies: + debug: 2.6.9 + encodeurl: 1.0.2 + escape-html: 1.0.3 + on-finished: 2.4.1 + parseurl: 1.3.3 + statuses: 2.0.1 + unpipe: 1.0.0 + transitivePeerDependencies: + - supports-color + finalhandler@1.3.1: dependencies: debug: 2.6.9 @@ -15159,6 +15267,10 @@ snapshots: argparse: 1.0.10 esprima: 4.0.1 + js-yaml@4.1.0: + dependencies: + argparse: 2.0.1 + js-yaml@4.1.1: dependencies: argparse: 2.0.1 @@ -15279,6 +15391,8 @@ snapshots: lodash.truncate@4.4.2: {} + lodash@4.17.21: {} + lodash@4.17.23: {} longest-streak@3.1.0: {} @@ -15562,6 +15676,8 @@ snapshots: meow@14.1.0: {} + merge-descriptors@1.0.1: {} + merge-descriptors@1.0.3: {} merge2@1.4.1: {} @@ -15896,6 +16012,10 @@ snapshots: dependencies: brace-expansion: 5.0.4 + minimatch@3.1.2: + dependencies: + brace-expansion: 1.1.12 + minimatch@3.1.5: dependencies: brace-expansion: 1.1.12 @@ -16283,6 +16403,8 @@ snapshots: path-to-regexp@0.1.12: {} + path-to-regexp@0.1.7: {} + path-type@4.0.0: {} pathe@2.0.3: {} @@ -16486,7 +16608,11 @@ snapshots: dependencies: hookified: 1.15.1 - qs@6.14.2: + qs@6.11.0: + dependencies: + side-channel: 1.1.0 + + qs@6.13.0: dependencies: side-channel: 1.1.0 @@ -16496,6 +16622,13 @@ snapshots: range-parser@1.2.1: {} + raw-body@2.5.1: + dependencies: + bytes: 3.1.2 + http-errors: 2.0.0 + iconv-lite: 0.4.24 + unpipe: 1.0.0 + raw-body@2.5.2: dependencies: bytes: 3.1.2 @@ -17167,6 +17300,24 @@ snapshots: semver@7.7.4: {} + send@0.18.0: + dependencies: + debug: 2.6.9 + depd: 2.0.0 + destroy: 1.2.0 + encodeurl: 1.0.2 + escape-html: 1.0.3 + etag: 1.8.1 + fresh: 0.5.2 + http-errors: 2.0.0 + mime: 1.6.0 + ms: 2.1.3 + on-finished: 2.4.1 + range-parser: 1.2.1 + statuses: 2.0.1 + transitivePeerDependencies: + - supports-color + send@0.19.0: dependencies: debug: 2.6.9 @@ -17190,6 +17341,15 @@ snapshots: non-error: 0.1.0 type-fest: 5.4.4 + serve-static@1.15.0: + dependencies: + encodeurl: 1.0.2 + escape-html: 1.0.3 + parseurl: 1.3.3 + send: 0.18.0 + transitivePeerDependencies: + - supports-color + serve-static@1.16.2: dependencies: encodeurl: 2.0.0 @@ -17752,7 +17912,7 @@ snapshots: - bare-buffer - react-native-b4a - tar@6.2.1: + tar@6.1.15: dependencies: chownr: 2.0.0 fs-minipass: 2.1.0 @@ -18432,15 +18592,15 @@ snapshots: yoga-layout@3.2.1: {} - zod-to-json-schema@3.20.4(zod@3.22.3): + zod-to-json-schema@3.20.4(zod@3.21.4): dependencies: - zod: 3.22.3 + zod: 3.21.4 zod-to-json-schema@3.20.4(zod@3.24.0): dependencies: zod: 3.24.0 - zod@3.22.3: {} + zod@3.21.4: {} zod@3.23.8: {} From ebb53268298b5936bf6c77512c617631107ee22d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petr=20Bula=CC=81nek?= Date: Fri, 27 Mar 2026 16:54:07 +0100 Subject: [PATCH 3/4] refactor(sdk): align citation and trajectory extension schemas with array pattern MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Petr Bulánek --- apps/adk-ts/src/client/a2a/extensions/schemas.ts | 1 + .../adk-ts/src/client/a2a/extensions/ui/citation/index.ts | 8 ++++---- .../src/client/a2a/extensions/ui/citation/schemas.ts | 4 +--- .../src/client/a2a/extensions/ui/streaming/index.ts | 3 --- .../src/client/a2a/extensions/ui/streaming/schemas.ts | 2 +- .../src/client/a2a/extensions/ui/trajectory/index.ts | 4 ++-- .../src/client/a2a/extensions/ui/trajectory/schemas.ts | 4 +++- .../src/client/a2a/extensions/ui/trajectory/types.ts | 4 +++- apps/adk-ui/src/api/a2a/utils.ts | 7 +++---- 9 files changed, 18 insertions(+), 19 deletions(-) diff --git a/apps/adk-ts/src/client/a2a/extensions/schemas.ts b/apps/adk-ts/src/client/a2a/extensions/schemas.ts index feaa27ec..57ff688f 100644 --- a/apps/adk-ts/src/client/a2a/extensions/schemas.ts +++ b/apps/adk-ts/src/client/a2a/extensions/schemas.ts @@ -17,4 +17,5 @@ export * from './ui/canvas/schemas'; export * from './ui/citation/schemas'; export * from './ui/error/schemas'; export * from './ui/settings/schemas'; +export * from './ui/streaming/schemas'; export * from './ui/trajectory/schemas'; diff --git a/apps/adk-ts/src/client/a2a/extensions/ui/citation/index.ts b/apps/adk-ts/src/client/a2a/extensions/ui/citation/index.ts index e88140a7..c18c34de 100644 --- a/apps/adk-ts/src/client/a2a/extensions/ui/citation/index.ts +++ b/apps/adk-ts/src/client/a2a/extensions/ui/citation/index.ts @@ -6,12 +6,12 @@ import z from 'zod'; import type { A2AUiExtension } from '../../../../core/extensions/types'; -import { citationSchema } from './schemas'; -import type { Citation } from './types'; +import { citationMetadataSchema } from './schemas'; +import type { CitationMetadata } from './types'; export const CITATION_EXTENSION_URI = 'https://a2a-extensions.adk.kagenti.dev/ui/citation/v1'; -export const citationExtension: A2AUiExtension = { +export const citationExtension: A2AUiExtension = { getUri: () => CITATION_EXTENSION_URI, - getMessageMetadataSchema: () => z.object({ [CITATION_EXTENSION_URI]: z.array(citationSchema) }).partial(), + getMessageMetadataSchema: () => z.object({ [CITATION_EXTENSION_URI]: citationMetadataSchema }).partial(), }; diff --git a/apps/adk-ts/src/client/a2a/extensions/ui/citation/schemas.ts b/apps/adk-ts/src/client/a2a/extensions/ui/citation/schemas.ts index 291a0d69..2aaf8ed7 100644 --- a/apps/adk-ts/src/client/a2a/extensions/ui/citation/schemas.ts +++ b/apps/adk-ts/src/client/a2a/extensions/ui/citation/schemas.ts @@ -13,6 +13,4 @@ export const citationSchema = z.object({ description: z.string().nullish(), }); -export const citationMetadataSchema = z.object({ - citations: z.array(citationSchema), -}); +export const citationMetadataSchema = z.array(citationSchema); diff --git a/apps/adk-ts/src/client/a2a/extensions/ui/streaming/index.ts b/apps/adk-ts/src/client/a2a/extensions/ui/streaming/index.ts index 04115242..83b49959 100644 --- a/apps/adk-ts/src/client/a2a/extensions/ui/streaming/index.ts +++ b/apps/adk-ts/src/client/a2a/extensions/ui/streaming/index.ts @@ -9,9 +9,6 @@ import type { A2AUiExtension } from '../../../../core/extensions/types'; import { streamingMetadataSchema } from './schemas'; import type { StreamingMetadata } from './types'; -export { streamingMetadataSchema, streamingPatchSchema } from './schemas'; -export type { StreamingMetadata, StreamingPatch } from './types'; - export const STREAMING_EXTENSION_URI = 'https://a2a-extensions.adk.kagenti.dev/ui/streaming/v1'; export const streamingExtension: A2AUiExtension = { diff --git a/apps/adk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts b/apps/adk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts index 3cce4408..130faf0c 100644 --- a/apps/adk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts +++ b/apps/adk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts @@ -9,7 +9,7 @@ export const streamingPatchSchema = z.object({ op: z.string(), path: z.string(), value: z.unknown().optional(), - pos: z.number().optional(), // for str_ins + pos: z.number().optional(), }); export const streamingMetadataSchema = z.object({ diff --git a/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/index.ts b/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/index.ts index 7ab8aac8..d18b230d 100644 --- a/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/index.ts +++ b/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/index.ts @@ -11,7 +11,7 @@ import type { TrajectoryMetadata } from './types'; export const TRAJECTORY_EXTENSION_URI = 'https://a2a-extensions.adk.kagenti.dev/ui/trajectory/v1'; -export const trajectoryExtension: A2AUiExtension = { +export const trajectoryExtension: A2AUiExtension = { getUri: () => TRAJECTORY_EXTENSION_URI, - getMessageMetadataSchema: () => z.object({ [TRAJECTORY_EXTENSION_URI]: z.array(trajectoryMetadataSchema) }).partial(), + getMessageMetadataSchema: () => z.object({ [TRAJECTORY_EXTENSION_URI]: trajectoryMetadataSchema }).partial(), }; diff --git a/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/schemas.ts b/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/schemas.ts index b6a03249..c7342bf8 100644 --- a/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/schemas.ts +++ b/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/schemas.ts @@ -5,8 +5,10 @@ import z from 'zod'; -export const trajectoryMetadataSchema = z.object({ +export const trajectorySchema = z.object({ title: z.string().nullish(), content: z.string().nullish(), group_id: z.string().nullish(), }); + +export const trajectoryMetadataSchema = z.array(trajectorySchema); diff --git a/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/types.ts b/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/types.ts index e5cbaf19..fc175636 100644 --- a/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/types.ts +++ b/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/types.ts @@ -5,6 +5,8 @@ import type z from 'zod'; -import type { trajectoryMetadataSchema } from './schemas'; +import type { trajectoryMetadataSchema, trajectorySchema } from './schemas'; + +export type Trajectory = z.infer; export type TrajectoryMetadata = z.infer; diff --git a/apps/adk-ui/src/api/a2a/utils.ts b/apps/adk-ui/src/api/a2a/utils.ts index 303ee976..cd5eeab0 100644 --- a/apps/adk-ui/src/api/a2a/utils.ts +++ b/apps/adk-ui/src/api/a2a/utils.ts @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { FilePart, Message, Part, TextPart } from '@kagenti/adk'; +import type { FilePart, Message, Part, TextPart, Trajectory } from '@kagenti/adk'; import { type Citation, citationExtension, @@ -11,7 +11,6 @@ import { extractUiExtensionData, streamingExtension, trajectoryExtension, - type TrajectoryMetadata, } from '@kagenti/adk'; import truncate from 'lodash/truncate'; import { v4 as uuid } from 'uuid'; @@ -124,8 +123,8 @@ export function createSourcePart(citation: Citation, taskId: string | undefined return sourcePart; } -export function createTrajectoryPart(metadata: TrajectoryMetadata): UITrajectoryPart { - const { title, content, group_id } = metadata; +export function createTrajectoryPart(trajectory: Trajectory): UITrajectoryPart { + const { title, content, group_id } = trajectory; const trajectoryPart: UITrajectoryPart = { kind: UIMessagePartKind.Trajectory, From a2df07403aa304c8c4f0636b43e16bbc3f6cb997 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petr=20Bula=CC=81nek?= Date: Fri, 27 Mar 2026 17:02:24 +0100 Subject: [PATCH 4/4] fix(ui): tidy agent-card extensions extraction and trajectory part mapping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Petr Bulánek --- apps/adk-ui/src/api/a2a/agent-card.ts | 2 +- apps/adk-ui/src/api/a2a/part-processors.ts | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/apps/adk-ui/src/api/a2a/agent-card.ts b/apps/adk-ui/src/api/a2a/agent-card.ts index 2ab80317..43558697 100644 --- a/apps/adk-ui/src/api/a2a/agent-card.ts +++ b/apps/adk-ui/src/api/a2a/agent-card.ts @@ -18,8 +18,8 @@ export async function getAgentClient(providerId: string, token: string): Promise const endpointUrl = `${baseUrl}/api/v1/a2a/${providerId}/`; const agentCard = await fetchAgentCard(agentCardUrl, fetchImpl); + const extensions = agentCard.capabilities.extensions?.map(({ uri }) => uri); - const extensions = agentCard.capabilities?.extensions?.map((ext) => ext.uri).filter(Boolean) as string[]; return createA2AClient({ endpointUrl, agentCard, fetchImpl, extensions }); } diff --git a/apps/adk-ui/src/api/a2a/part-processors.ts b/apps/adk-ui/src/api/a2a/part-processors.ts index e181cf59..78956b0b 100644 --- a/apps/adk-ui/src/api/a2a/part-processors.ts +++ b/apps/adk-ui/src/api/a2a/part-processors.ts @@ -20,15 +20,15 @@ import { } from './utils'; export function processMessageMetadata(message: Message): UIMessagePart[] { - const trajectory = extractTrajectory(message.metadata); + const trajectories = extractTrajectory(message.metadata); const citations = extractCitation(message.metadata); const parts: UIMessagePart[] = []; - if (trajectory) { - for (const item of trajectory) { - parts.push(createTrajectoryPart(item)); - } + if (trajectories) { + const trajectoryParts = trajectories.map((trajectory) => createTrajectoryPart(trajectory)); + + parts.push(...trajectoryParts); } if (citations) { const sourceParts = citations.map((citation) => createSourcePart(citation, message.taskId)).filter(isNotNull);