From ad3c581497cd1249a0e67ed8bd8881f64b311ba5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Radek=20Je=C5=BEek?= Date: Wed, 25 Mar 2026 16:23:24 +0100 Subject: [PATCH 1/2] feat(sdk): add streaming extensions, response accumulator, and UI streaming support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a delta-based streaming protocol using JSON Patch (RFC 6902) with a custom `str_ins` operation for efficient token-by-token text delivery. Server (adk-py): - MessageAccumulator: 3-level state machine collecting string chunks, parts, and metadata into complete messages with incremental patches - Extended JSON Patch (jsonpatch_ext): str_ins operation for optimized text streaming without full-value replacement - StreamingExtensionServer/Client: extension negotiation, patch extraction, delta emission, and graceful fallback for non-streaming clients - Refactored extension base: activation lifecycle, default demand constant consolidation (DEFAULT_DEMAND_NAME), and metadata handling - RunContext: extracted shared store logic into _prepare_store_data, added store_sync for synchronous callers - New yield types: ArtifactChunk for chunked artifact streaming Client (adk-ts / adk-ui): - TypeScript streaming extension types and Zod schemas - UI patch application (streaming.ts): JSON pointer resolution, replace/add/str_ins ops with optimized cloning (skip structuredClone for primitives) - Client streaming integration: draft accumulation, patch extraction from status update metadata, replace-mode emission - Fix: text duplication on final message — detect active streaming session and emit with replace flag to prevent appending duplicate content - Batched React state updates in AgentRunProvider to reduce re-renders during streaming - Removed dead commented-out proxy header code from adk-client Tests: - Unit tests for MessageAccumulator state machine - Unit tests for extended JSON Patch operations - Unit tests for RunContext store methods - E2E 
streaming tests: patch application, message reconstruction, client fallback behavior - Updated yield tests for new artifact chunking Assisted-By: Claude (Anthropic AI) Signed-off-by: Radek Ježek --- apps/adk-cli/src/kagenti_cli/api.py | 17 +- .../adk-cli/src/kagenti_cli/commands/agent.py | 273 +++++----- apps/adk-cli/uv.lock | 25 + apps/adk-py/pyproject.toml | 2 + .../a2a/extensions/auth/oauth/oauth.py | 24 +- .../a2a/extensions/auth/secrets/secrets.py | 17 +- .../src/kagenti_adk/a2a/extensions/base.py | 62 ++- .../a2a/extensions/interactions/approval.py | 8 +- .../a2a/extensions/services/embedding.py | 28 +- .../a2a/extensions/services/form.py | 2 +- .../a2a/extensions/services/llm.py | 26 +- .../a2a/extensions/services/mcp.py | 23 +- .../a2a/extensions/services/platform.py | 6 +- .../kagenti_adk/a2a/extensions/streaming.py | 335 ++++++++++++ .../kagenti_adk/a2a/extensions/tools/call.py | 8 +- .../kagenti_adk/a2a/extensions/ui/__init__.py | 117 ++++- .../a2a/extensions/ui/agent_detail.py | 2 +- .../kagenti_adk/a2a/extensions/ui/canvas.py | 2 +- .../kagenti_adk/a2a/extensions/ui/citation.py | 6 +- .../kagenti_adk/a2a/extensions/ui/error.py | 2 +- .../a2a/extensions/ui/form_request.py | 2 +- .../kagenti_adk/a2a/extensions/ui/settings.py | 2 +- .../a2a/extensions/ui/trajectory.py | 8 +- .../src/kagenti_adk/server/accumulator.py | 219 ++++++++ apps/adk-py/src/kagenti_adk/server/agent.py | 277 +++++----- .../src/kagenti_adk/server/constants.py | 2 + apps/adk-py/src/kagenti_adk/server/context.py | 37 +- .../src/kagenti_adk/server/dependencies.py | 38 +- .../src/kagenti_adk/server/exceptions.py | 7 +- .../src/kagenti_adk/server/jsonpatch_ext.py | 160 ++++++ apps/adk-py/src/kagenti_adk/server/utils.py | 39 +- apps/adk-py/src/kagenti_adk/types.py | 19 +- apps/adk-py/tests/conftest.py | 9 - apps/adk-py/tests/e2e/conftest.py | 56 ++- apps/adk-py/tests/e2e/test_streaming.py | 476 ++++++++++++++++++ apps/adk-py/tests/e2e/test_yields.py | 135 +++-- 
apps/adk-py/tests/test_merge_utils.py | 60 +++ .../tests/unit/server/test_accumulator.py | 292 +++++++++++ apps/adk-py/tests/unit/server/test_context.py | 188 +++++++ .../tests/unit/server/test_jsonpatch_ext.py | 78 +++ .../unit/test_agent_detail_population.py | 5 + apps/adk-py/tests/unit/test_dependencies.py | 19 +- apps/adk-py/uv.lock | 25 + apps/adk-server/pyproject.toml | 1 + .../src/adk_server/api/routes/a2a.py | 2 +- .../e2e/agents/test_platform_extensions.py | 5 +- apps/adk-server/uv.lock | 27 + .../adk-ts/src/client/a2a/extensions/index.ts | 1 + .../adk-ts/src/client/a2a/extensions/types.ts | 1 + .../a2a/extensions/ui/citation/index.ts | 8 +- .../a2a/extensions/ui/streaming/index.ts | 20 + .../a2a/extensions/ui/streaming/schemas.ts | 18 + .../a2a/extensions/ui/streaming/types.ts | 12 + .../a2a/extensions/ui/trajectory/index.ts | 4 +- apps/adk-ui/src/api/a2a/agent-card.ts | 3 +- apps/adk-ui/src/api/a2a/client.ts | 33 +- apps/adk-ui/src/api/a2a/jsonrpc-client.ts | 12 +- apps/adk-ui/src/api/a2a/part-processors.ts | 8 +- apps/adk-ui/src/api/a2a/streaming.ts | 136 +++++ apps/adk-ui/src/api/a2a/types.ts | 6 +- apps/adk-ui/src/api/a2a/utils.ts | 2 + apps/adk-ui/src/api/adk-client.ts | 15 - apps/adk-ui/src/app/(auth)/auth.ts | 2 +- .../src/modules/platform-context/constants.ts | 2 +- .../contexts/agent-run/AgentRunProvider.tsx | 17 +- .../a2a/extensions/ui/streaming/index.ts | 20 + .../a2a/extensions/ui/streaming/schemas.ts | 18 + .../a2a/extensions/ui/streaming/types.ts | 12 + 68 files changed, 3028 insertions(+), 495 deletions(-) create mode 100644 apps/adk-py/src/kagenti_adk/a2a/extensions/streaming.py create mode 100644 apps/adk-py/src/kagenti_adk/server/accumulator.py create mode 100644 apps/adk-py/src/kagenti_adk/server/jsonpatch_ext.py delete mode 100644 apps/adk-py/tests/conftest.py create mode 100644 apps/adk-py/tests/e2e/test_streaming.py create mode 100644 apps/adk-py/tests/test_merge_utils.py create mode 100644 
apps/adk-py/tests/unit/server/test_accumulator.py create mode 100644 apps/adk-py/tests/unit/server/test_context.py create mode 100644 apps/adk-py/tests/unit/server/test_jsonpatch_ext.py create mode 100644 apps/adk-ts/src/client/a2a/extensions/ui/streaming/index.ts create mode 100644 apps/adk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts create mode 100644 apps/adk-ts/src/client/a2a/extensions/ui/streaming/types.ts create mode 100644 apps/adk-ui/src/api/a2a/streaming.ts create mode 100644 apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/index.ts create mode 100644 apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts create mode 100644 apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/types.ts diff --git a/apps/adk-cli/src/kagenti_cli/api.py b/apps/adk-cli/src/kagenti_cli/api.py index ba62c71c..bf16d964 100644 --- a/apps/adk-cli/src/kagenti_cli/api.py +++ b/apps/adk-cli/src/kagenti_cli/api.py @@ -16,7 +16,8 @@ import httpx import openai import pydantic -from a2a.client import A2AClientError, Client, ClientConfig, ClientFactory +from a2a.client import A2AClientError, Client, ClientCallContext, ClientConfig, ClientFactory +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.types import AgentCard from kagenti_adk.platform.context import ContextToken from google.protobuf.json_format import MessageToDict @@ -131,8 +132,18 @@ class OpenAPISchema(pydantic.BaseModel): return None +def make_extension_context(extensions: list[str] | None = None) -> ClientCallContext | None: + """Create a ClientCallContext with extension URIs as service parameters.""" + if not extensions: + return None + return ClientCallContext(service_parameters={HTTP_EXTENSION_HEADER: ",".join(extensions)}) + + @asynccontextmanager -async def a2a_client(agent_card: AgentCard, context_token: ContextToken) -> AsyncIterator[Client]: +async def a2a_client( + agent_card: AgentCard, + context_token: ContextToken, +) -> AsyncIterator[Client]: try: 
async with httpx.AsyncClient( headers={"Authorization": f"Bearer {context_token.token.get_secret_value()}"}, @@ -140,7 +151,7 @@ async def a2a_client(agent_card: AgentCard, context_token: ContextToken) -> Asyn timeout=timedelta(hours=1).total_seconds(), ) as httpx_client: yield ClientFactory(ClientConfig(httpx_client=httpx_client, use_client_preference=True)).create( - card=agent_card + card=agent_card, ) except A2AClientError as ex: card_data = json.dumps( diff --git a/apps/adk-cli/src/kagenti_cli/commands/agent.py b/apps/adk-cli/src/kagenti_cli/commands/agent.py index 635a6a19..6e88a5e7 100644 --- a/apps/adk-cli/src/kagenti_cli/commands/agent.py +++ b/apps/adk-cli/src/kagenti_cli/commands/agent.py @@ -22,6 +22,7 @@ Message, Part, Role, + SendMessageRequest, TaskState, ) from kagenti_adk.a2a.extensions import ( @@ -35,9 +36,18 @@ LLMServiceExtensionSpec, PlatformApiExtensionClient, PlatformApiExtensionSpec, - TrajectoryExtensionClient, + Trajectory, TrajectoryExtensionSpec, ) +from kagenti_adk.a2a.extensions.streaming import ( + ArtifactDelta, + MetadataDelta, + PartDelta, + StateChange, + StreamingExtensionClient, + StreamingExtensionSpec, + TextDelta, +) from kagenti_adk.a2a.extensions.common.form import ( CheckboxField, CheckboxFieldValue, @@ -115,7 +125,7 @@ from rich.markdown import Markdown from rich.table import Column -from kagenti_cli.api import a2a_client +from kagenti_cli.api import a2a_client, make_extension_context from kagenti_cli.async_typer import AsyncTyper, console, create_table, err_console from kagenti_cli.server_utils import announce_server_action, confirm_server_action from kagenti_cli.utils import ( @@ -199,14 +209,18 @@ async def _discover_agent_card(location: str) -> AgentCard: @app.command("add") async def add_agent( - location: typing.Annotated[ - str | None, typer.Argument(help="Agent image or network URL") + location: typing.Annotated[str | None, typer.Argument(help="Agent image or network URL")] = None, + name: typing.Annotated[ + str 
| None, typer.Option("--name", "-n", help="Agent name (default: derived from image)") ] = None, - name: typing.Annotated[str | None, typer.Option("--name", "-n", help="Agent name (default: derived from image)")] = None, namespace: typing.Annotated[str, typer.Option(help="Target Kubernetes namespace")] = "team1", port: typing.Annotated[int, typer.Option(help="Agent service port")] = 8080, - env: typing.Annotated[list[str] | None, typer.Option("--env", "-e", help="Environment variable in KEY=VALUE format (repeatable)")] = None, - env_file: typing.Annotated[str | None, typer.Option("--env-file", help="Path to env file (KEY=VALUE per line)")] = None, + env: typing.Annotated[ + list[str] | None, typer.Option("--env", "-e", help="Environment variable in KEY=VALUE format (repeatable)") + ] = None, + env_file: typing.Annotated[ + str | None, typer.Option("--env-file", help="Path to env file (KEY=VALUE per line)") + ] = None, yes: typing.Annotated[bool, typer.Option("--yes", "-y", help="Skip confirmation prompts.")] = False, ) -> None: """Add an agent by container image or network URL. [Admin only]""" @@ -394,9 +408,7 @@ async def update_agent( search_path: typing.Annotated[ str | None, typer.Argument(help="Short ID, agent name or part of the provider location of agent to replace") ] = None, - location: typing.Annotated[ - str | None, typer.Argument(help="New agent location (network URL)") - ] = None, + location: typing.Annotated[str | None, typer.Argument(help="New agent location (network URL)")] = None, yes: typing.Annotated[bool, typer.Option("--yes", "-y", help="Skip confirmation prompts.")] = False, ) -> None: """Update an agent's location. 
[Admin only]""" @@ -760,9 +772,9 @@ async def _run_agent( console_status_stopped = False log_type = None + pending_form_metadata = None - trajectory_spec = TrajectoryExtensionSpec.from_agent_card(agent_card) - trajectory_extension = TrajectoryExtensionClient(trajectory_spec) if trajectory_spec else None + has_trajectory = TrajectoryExtensionSpec.from_agent_card(agent_card) is not None llm_spec = LLMServiceExtensionSpec.from_agent_card(agent_card) embedding_spec = EmbeddingServiceExtensionSpec.from_agent_card(agent_card) platform_extension_spec = PlatformApiExtensionSpec.from_agent_card(agent_card) @@ -860,30 +872,39 @@ async def _run_agent( metadata=metadata, ) - stream = client.send_message(msg) + streaming_spec = StreamingExtensionSpec.from_agent_card(agent_card) + streaming = StreamingExtensionClient(streaming_spec or StreamingExtensionSpec()) + extension_context = make_extension_context([ext.uri for ext in agent_card.capabilities.extensions or []]) + stream = client.send_message(SendMessageRequest(message=msg), context=extension_context) while True: - async for response, task in stream: + async for delta, task_obj in streaming.stream(stream): if not console_status_stopped: console_status_stopped = True console_status.stop() - task_id = task.id if task else task_id - - if response.HasField("status_update"): - update = response.status_update - status = update.status - state = status.state - message = status.message if status.HasField("message") else None - - if state == TaskState.TASK_STATE_COMPLETED: - console.print() # Add newline after completion - return - - elif state in (TaskState.TASK_STATE_WORKING, TaskState.TASK_STATE_SUBMITTED): - # Handle streaming content during working state - if message: - if trajectory_extension and (trajectory := trajectory_extension.parse_server_metadata(message)): + task_id = task_obj.id if task_obj else task_id + + match delta: + case TextDelta(delta=text): + if log_type: + err_console.print() + log_type = None + 
console.print(text, end="") + + case PartDelta(part=part): + if log_type: + err_console.print() + log_type = None + if "text" in part: + console.print(part["text"], end="") + + case MetadataDelta(metadata=meta): + if FormRequestExtensionSpec.URI in meta: + pending_form_metadata = meta[FormRequestExtensionSpec.URI] + if has_trajectory and TrajectoryExtensionSpec.URI in meta: + for entry in meta[TrajectoryExtensionSpec.URI]: + trajectory = Trajectory.model_validate(entry) if update_kind := trajectory.title: if update_kind != log_type: if log_type is not None: @@ -891,109 +912,106 @@ async def _run_agent( err_console.print(f"{update_kind}: ", style="dim", end="") log_type = update_kind err_console.print(trajectory.content or "", style="dim", end="") - else: - # This is regular message content - if log_type: - console.print() - log_type = None - for part in message.parts: - if part.HasField("text"): - console.print(part.text, end="") - - elif state == TaskState.TASK_STATE_INPUT_REQUIRED: - if handle_input is None: - raise ValueError("Agent requires input but no input handler provided") - - if form_metadata := ( - MessageToDict(message.metadata).get(FormRequestExtensionSpec.URI) - if message and message.metadata - else None - ): - stream = client.send_message( - Message( - message_id=str(uuid4()), - parts=[], - role=Role.ROLE_USER, - task_id=task_id, - context_id=context_token.context_id, - metadata={ - FormRequestExtensionSpec.URI: ( - await _ask_form_questions(FormRender.model_validate(form_metadata)) - ).model_dump(mode="json") - }, + + case ArtifactDelta(event=artifact_event): + artifact = artifact_event.artifact + if dump_files_path is None: + continue + dump_files_path.mkdir(parents=True, exist_ok=True) + full_path = dump_files_path / (artifact.name or "unnamed").lstrip("/") + full_path.resolve().relative_to(dump_files_path.resolve()) + full_path.parent.mkdir(parents=True, exist_ok=True) + try: + for part in artifact.parts[:1]: + if part.HasField("raw"): + 
full_path.write_bytes(part.raw) + console.print(f"📁 Saved {full_path}") + elif part.HasField("url"): + uri = part.url + if uri.startswith("agentstack://"): + async with File.load_content(uri.removeprefix("agentstack://")) as file: + full_path.write_bytes(file.content) + else: + async with httpx.AsyncClient() as httpx_client: + full_path.write_bytes((await httpx_client.get(uri)).content) + console.print(f"📁 Saved {full_path}") + elif part.HasField("text"): + full_path.write_text(part.text) + else: + console.print(f"⚠️ Artifact part {type(part).__name__} is not supported") + if len(artifact.parts) > 1: + console.print("⚠️ Artifact with more than 1 part are not supported.") + except ValueError: + console.print(f"⚠️ Skipping artifact {artifact.name} - outside dump directory") + + case StateChange(state=state): + if log_type: + err_console.print() + log_type = None + if state == TaskState.TASK_STATE_COMPLETED: + console.print() # Add newline after completion + return + + elif state in (TaskState.TASK_STATE_WORKING, TaskState.TASK_STATE_SUBMITTED): + pass + + elif state == TaskState.TASK_STATE_INPUT_REQUIRED: + if handle_input is None: + raise ValueError("Agent requires input but no input handler provided") + + if pending_form_metadata: + stream = client.send_message( + SendMessageRequest( + message=Message( + message_id=str(uuid4()), + parts=[], + role=Role.ROLE_USER, + task_id=task_id, + context_id=context_token.context_id, + metadata={ + FormRequestExtensionSpec.URI: ( + await _ask_form_questions( + FormRender.model_validate(pending_form_metadata) + ) + ).model_dump(mode="json") + }, + ) + ), + context=extension_context, ) - ) - break + pending_form_metadata = None + break - text = "" - for part in message.parts if message else []: - if part.HasField("text"): - text = part.text - console.print(f"\n[bold]Agent requires your input[/bold]: {text}\n") - user_input = handle_input() - stream = client.send_message( - Message( - message_id=str(uuid4()), - 
parts=[Part(text=user_input)], - role=Role.ROLE_USER, - task_id=task_id, - context_id=context_token.context_id, + console.print("\n[bold]Agent requires your input[/bold]\n") + user_input = handle_input() + stream = client.send_message( + SendMessageRequest( + message=Message( + message_id=str(uuid4()), + parts=[Part(text=user_input)], + role=Role.ROLE_USER, + task_id=task_id, + context_id=context_token.context_id, + ) + ), + context=extension_context, ) - ) - break - - elif state in ( - TaskState.TASK_STATE_CANCELED, - TaskState.TASK_STATE_FAILED, - TaskState.TASK_STATE_REJECTED, - ): - error = "" - if message and message.parts and message.parts[0].HasField("text"): - error = message.parts[0].text - console.print(f"\n:boom: [red][bold]Task {TaskState.Name(state)}[/bold][/red]") - console.print(Markdown(error)) - return - - elif state == TaskState.TASK_STATE_AUTH_REQUIRED: - console.print("[yellow]Authentication required[/yellow]") - return + break - else: - console.print(f"[yellow]Unknown task status: {state}[/yellow]") + elif state in ( + TaskState.TASK_STATE_CANCELED, + TaskState.TASK_STATE_FAILED, + TaskState.TASK_STATE_REJECTED, + ): + console.print(f"\n:boom: [red][bold]Task {TaskState.Name(state)}[/bold][/red]") + return - elif response.HasField("artifact_update"): - artifact = response.artifact_update.artifact - if dump_files_path is None: - continue - dump_files_path.mkdir(parents=True, exist_ok=True) - full_path = dump_files_path / (artifact.name or "unnamed").lstrip("/") - full_path.resolve().relative_to(dump_files_path.resolve()) - full_path.parent.mkdir(parents=True, exist_ok=True) - try: - for part in artifact.parts[:1]: - if part.HasField("raw"): - full_path.write_bytes(part.raw) - console.print(f"📁 Saved {full_path}") - elif part.HasField("url"): - uri = part.url - if uri.startswith("adk://"): - async with File.load_content(uri.removeprefix("adk://")) as file: - full_path.write_bytes(file.content) - else: - async with httpx.AsyncClient() as 
httpx_client: - full_path.write_bytes((await httpx_client.get(uri)).content) - console.print(f"📁 Saved {full_path}") - elif part.HasField("text"): - full_path.write_text(part.text) - else: - console.print(f"⚠️ Artifact part {type(part).__name__} is not supported") - if len(artifact.parts) > 1: - console.print("⚠️ Artifact with more than 1 part are not supported.") - except ValueError: - console.print(f"⚠️ Skipping artifact {artifact.name} - outside dump directory") + elif state == TaskState.TASK_STATE_AUTH_REQUIRED: + console.print("[yellow]Authentication required[/yellow]") + return - else: - print(response) + else: + console.print(f"[yellow]Unknown task status: {state}[/yellow]") else: break # Stream ended normally @@ -1295,6 +1313,7 @@ async def run_agent( settings=settings_input, dump_files_path=dump_files, handle_input=handle_input, + ) console.print() turn_input = handle_input() diff --git a/apps/adk-cli/uv.lock b/apps/adk-cli/uv.lock index 4f0207f0..9e548f30 100644 --- a/apps/adk-cli/uv.lock +++ b/apps/adk-cli/uv.lock @@ -557,6 +557,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/94/9e/820c4b086ad01ba7d77369fb8b11470a01fac9b4977f02e18659cf378b6b/json_rpc-1.15.0-py2.py3-none-any.whl", hash = "sha256:4a4668bbbe7116feb4abbd0f54e64a4adcf4b8f648f19ffa0848ad0f6606a9bf", size = 39450, upload-time = "2023-06-11T09:45:47.136Z" }, ] +[[package]] +name = "jsonpatch" +version = "1.33" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jsonpointer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/78/18813351fe5d63acad16aec57f94ec2b70a09e53ca98145589e185423873/jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c", size = 21699, upload-time = "2023-06-26T12:07:29.144Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/07/02e16ed01e04a374e644b575638ec7987ae846d25ad97bcc9945a3ee4b0e/jsonpatch-1.33-py2.py3-none-any.whl", hash = 
"sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade", size = 12898, upload-time = "2023-06-16T21:01:28.466Z" }, +] + +[[package]] +name = "jsonpointer" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/0a/eebeb1fa92507ea94016a2a790b93c2ae41a7e18778f85471dc54475ed25/jsonpointer-3.0.0.tar.gz", hash = "sha256:2b2d729f2091522d61c3b31f82e11870f60b68f43fbc705cb76bf4b832af59ef", size = 9114, upload-time = "2024-06-10T19:24:42.462Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/71/92/5e77f98553e9e75130c78900d000368476aed74276eb8ae8796f65f00918/jsonpointer-3.0.0-py2.py3-none-any.whl", hash = "sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942", size = 7595, upload-time = "2024-06-10T19:24:40.698Z" }, +] + [[package]] name = "jsonschema" version = "4.26.0" @@ -591,6 +612,7 @@ source = { editable = "../adk-py" } dependencies = [ { name = "a2a-sdk", extra = ["sqlite"] }, { name = "anyio" }, + { name = "asgiref" }, { name = "async-lru" }, { name = "asyncclick" }, { name = "authlib" }, @@ -598,6 +620,7 @@ dependencies = [ { name = "fastapi" }, { name = "httpx" }, { name = "janus" }, + { name = "jsonpatch" }, { name = "mcp" }, { name = "objprint" }, { name = "opentelemetry-api" }, @@ -617,6 +640,7 @@ dependencies = [ requires-dist = [ { name = "a2a-sdk", extras = ["sqlite"], specifier = "==1.0.0a0" }, { name = "anyio", specifier = ">=4.9.0" }, + { name = "asgiref", specifier = ">=3.11.0" }, { name = "async-lru", specifier = ">=2.0.4" }, { name = "asyncclick", specifier = ">=8.1.8" }, { name = "authlib", specifier = ">=1.3.0" }, @@ -624,6 +648,7 @@ requires-dist = [ { name = "fastapi", specifier = ">=0.116.1" }, { name = "httpx" }, { name = "janus", specifier = ">=2.0.0" }, + { name = "jsonpatch", specifier = ">=1.33" }, { name = "mcp", specifier = ">=1.12.3" }, { name = "objprint", specifier = ">=0.3.0" }, { name = "opentelemetry-api", 
specifier = ">=1.35.0" }, diff --git a/apps/adk-py/pyproject.toml b/apps/adk-py/pyproject.toml index 0424cce9..4e5d2e04 100644 --- a/apps/adk-py/pyproject.toml +++ b/apps/adk-py/pyproject.toml @@ -28,6 +28,8 @@ dependencies = [ "typing-extensions>=4.15.0", "opentelemetry-instrumentation-httpx>=0.60b1", "opentelemetry-instrumentation-openai>=0.52.3", + "asgiref>=3.11.0", + "jsonpatch>=1.33", ] [dependency-groups] diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/auth/oauth/oauth.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/auth/oauth/oauth.py index 03613f05..76b147ac 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/auth/oauth/oauth.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/auth/oauth/oauth.py @@ -19,7 +19,7 @@ from typing_extensions import override from kagenti_adk.a2a.extensions.auth.oauth.storage import MemoryTokenStorageFactory, TokenStorageFactory -from kagenti_adk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec +from kagenti_adk.a2a.extensions.base import DEFAULT_DEMAND_NAME, BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec from kagenti_adk.a2a.types import AgentMessage, AuthRequired, Metadata, RunYieldResume from kagenti_adk.util.pydantic import REVEAL_SECRETS, SecureBaseModel @@ -50,7 +50,6 @@ "OAuthFulfillment", ] -_DEFAULT_DEMAND_NAME = "default" class AuthRequest(SecureBaseModel): @@ -74,19 +73,22 @@ class OAuthExtensionParams(pydantic.BaseModel): """Server requests that the agent requires to be provided by the client.""" -class OAuthExtensionSpec(BaseExtensionSpec[OAuthExtensionParams]): - URI: str = "https://a2a-extensions.adk.kagenti.dev/auth/oauth/v1" - - @classmethod - def single_demand(cls, name: str = _DEFAULT_DEMAND_NAME) -> Self: - return cls(params=OAuthExtensionParams(oauth_demands={name: OAuthDemand()})) - - class OAuthExtensionMetadata(pydantic.BaseModel): oauth_fulfillments: dict[str, OAuthFulfillment] = {} """Provided servers corresponding to the server requests.""" 
+class OAuthExtensionSpec(BaseExtensionSpec[OAuthExtensionParams, OAuthExtensionMetadata]): + URI: str = "https://a2a-extensions.adk.kagenti.dev/auth/oauth/v1" + + @classmethod + def single_demand(cls, name: str = DEFAULT_DEMAND_NAME, default: OAuthFulfillment | None = None) -> Self: + return cls( + params=OAuthExtensionParams(oauth_demands={name: OAuthDemand()}), + default=OAuthExtensionMetadata(oauth_fulfillments={name: default}) if default else None, + ) + + class OAuthExtensionServer(BaseExtensionServer[OAuthExtensionSpec, OAuthExtensionMetadata]): context: RunContext token_storage_factory: TokenStorageFactory @@ -105,7 +107,7 @@ def _get_fulfillment_for_resource(self, resource_url: pydantic.AnyUrl): raise RuntimeError("No fulfillments found") fulfillment = self.data.oauth_fulfillments.get(str(resource_url)) or self.data.oauth_fulfillments.get( - _DEFAULT_DEMAND_NAME + DEFAULT_DEMAND_NAME ) if fulfillment: return fulfillment diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/auth/secrets/secrets.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/auth/secrets/secrets.py index 0193c295..f03660c0 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/auth/secrets/secrets.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/auth/secrets/secrets.py @@ -13,7 +13,7 @@ from opentelemetry import trace from typing_extensions import override -from kagenti_adk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec +from kagenti_adk.a2a.extensions.base import DEFAULT_DEMAND_NAME, BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec from kagenti_adk.a2a.types import AgentMessage, AuthRequired from kagenti_adk.util.pydantic import REDACT_SECRETS, REVEAL_SECRETS, SecureBaseModel from kagenti_adk.util.telemetry import flatten_dict @@ -64,15 +64,22 @@ class SecretsServiceExtensionMetadata(pydantic.BaseModel): secret_fulfillments: dict[str, SecretFulfillment] = {} -class SecretsExtensionSpec(BaseExtensionSpec[SecretsServiceExtensionParams 
| None]): +class SecretsExtensionSpec(BaseExtensionSpec[SecretsServiceExtensionParams | None, SecretsServiceExtensionMetadata]): URI: str = "https://a2a-extensions.adk.kagenti.dev/auth/secrets/v1" @classmethod - def single_demand(cls, name: str, key: str | None = None, description: str | None = None) -> Self: + def single_demand( + cls, + name: str, + key: str = DEFAULT_DEMAND_NAME, + description: str | None = None, + default: SecretFulfillment | None = None, + ) -> Self: return cls( params=SecretsServiceExtensionParams( - secret_demands={key or "default": SecretDemand(description=description, name=name)} - ) + secret_demands={key: SecretDemand(description=description, name=name)} + ), + default=SecretsServiceExtensionMetadata(secret_fulfillments={key: default}) if default else None, ) diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/base.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/base.py index e433722e..3135357f 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/base.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/base.py @@ -5,6 +5,7 @@ from __future__ import annotations import abc +import logging import typing from collections.abc import AsyncIterator from contextlib import asynccontextmanager @@ -29,8 +30,10 @@ ) ParamsT = typing.TypeVar("ParamsT") -MetadataFromClientT = typing.TypeVar("MetadataFromClientT") -MetadataFromServerT = typing.TypeVar("MetadataFromServerT") +MetadataFromClientT = typing.TypeVar("MetadataFromClientT", bound=BaseModel | NoneType) +MetadataFromServerT = typing.TypeVar("MetadataFromServerT", bound=BaseModel | list | NoneType) + +logger = logging.getLogger(__name__) if typing.TYPE_CHECKING: @@ -39,16 +42,18 @@ A2A_EXTENSION_URI = "a2a_extension.uri" A2A_EXTENSION_METADATA_RECEIVED_EVENT = "a2a_extension.metadata.received" +DEFAULT_DEMAND_NAME = "default" def _get_generic_args(cls: type, base_class: type) -> tuple[typing.Any, ...]: - for base in getattr(cls, "__orig_bases__", ()): - if typing.get_origin(base) is 
base_class and (args := typing.get_args(base)): - return args + for klass in cls.__mro__: + for base in getattr(klass, "__orig_bases__", ()): + if typing.get_origin(base) is base_class and (args := typing.get_args(base)): + return args raise TypeError(f"Missing Params type for {cls.__name__}") -class BaseExtensionSpec(abc.ABC, typing.Generic[ParamsT]): +class BaseExtensionSpec(abc.ABC, typing.Generic[ParamsT, MetadataFromClientT]): """ Base class for an A2A extension handler. @@ -76,12 +81,18 @@ def __init_subclass__(cls, **kwargs): Params from the agent card. """ - def __init__(self, params: ParamsT, required: bool = False) -> None: + default: MetadataFromClientT | None = None + """ + Default metadata to use if the client does not provide any. + """ + + def __init__(self, params: ParamsT, required: bool = False, default: MetadataFromClientT | None = None) -> None: """ Agent should construct an extension instance using the constructor. """ self.params = params self.required = required + self.default = default @classmethod def from_agent_card(cls: type[typing.Self], agent: AgentCard) -> typing.Self | None: @@ -111,9 +122,9 @@ def to_agent_card_extensions(self, *, required: bool | None = None) -> list[Agen ] -class NoParamsBaseExtensionSpec(BaseExtensionSpec[NoneType]): - def __init__(self, required: bool = False): - super().__init__(None, required) +class NoParamsBaseExtensionSpec(typing.Generic[MetadataFromClientT], BaseExtensionSpec[NoneType, MetadataFromClientT]): + def __init__(self, required: bool = False, default: MetadataFromClientT | None = None): + super().__init__(None, required, default) @classmethod @override @@ -123,7 +134,7 @@ def from_agent_card(cls, agent: AgentCard) -> typing.Self | None: return None -ExtensionSpecT = typing.TypeVar("ExtensionSpecT", bound=BaseExtensionSpec[typing.Any]) +ExtensionSpecT = typing.TypeVar("ExtensionSpecT", bound=BaseExtensionSpec[typing.Any, typing.Any]) class BaseExtensionServer(abc.ABC, typing.Generic[ExtensionSpecT, 
MetadataFromClientT]): @@ -131,6 +142,8 @@ class BaseExtensionServer(abc.ABC, typing.Generic[ExtensionSpecT, MetadataFromCl Type of the extension metadata, attached to messages. """ + _is_active: bool = False + def __init_subclass__(cls: type[Self], **kwargs): super().__init_subclass__(**kwargs) @@ -150,11 +163,28 @@ def current(cls) -> Self | None: return cls._context_var.get() @property - def data(self) -> MetadataFromClientT | None: - return self._metadata_from_client + def data(self) -> MetadataFromClientT: + if self.MetadataFromClient is NoneType: + return None # type: ignore + + if self._metadata_from_client: + if not self._is_active: + logger.warning("Extension metadata received but extension is not active.") + return self._metadata_from_client + + if self.spec.default is not None: + return self.spec.default + + if not self._is_active: + raise AttributeError(f"Cannot access 'data' attribute: extension '{self.spec.URI}' is not active.") + + raise AttributeError( + f"Extension '{self.spec.URI}' is active but no metadata provided and no default available." 
+ ) def __bool__(self): - return bool(self.data) + # fallback - if we receive metadata but not an extension activation header + return bool(self._is_active or self._metadata_from_client) def __init__(self, spec: ExtensionSpecT, *args, **kwargs) -> None: self.spec = spec @@ -181,6 +211,10 @@ def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, attributes=flatten_dict(self._metadata_from_client.model_dump(context={REDACT_SECRETS: True})), ) + if not self._is_active and request_context.call_context: + self._is_active = self.spec.URI in request_context.call_context.requested_extensions + request_context.call_context.activated_extensions.add(self.spec.URI) + def _fork(self) -> typing.Self: """Creates a clone of this instance with the same arguments as the original""" return type(self)(self.spec, *self._args, **self._kwargs) diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/interactions/approval.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/interactions/approval.py index a7e60cda..daf6b659 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/interactions/approval.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/interactions/approval.py @@ -115,14 +115,14 @@ class ApprovalExtensionParams(BaseModel): pass -class ApprovalExtensionSpec(BaseExtensionSpec[ApprovalExtensionParams]): - URI: str = "https://a2a-extensions.adk.kagenti.dev/interactions/approval/v1" - - class ApprovalExtensionMetadata(BaseModel): pass +class ApprovalExtensionSpec(BaseExtensionSpec[ApprovalExtensionParams, ApprovalExtensionMetadata]): + URI: str = "https://a2a-extensions.adk.kagenti.dev/interactions/approval/v1" + + class ApprovalExtensionServer(BaseExtensionServer[ApprovalExtensionSpec, ApprovalExtensionMetadata]): def create_request_message(self, *, request: ApprovalRequest): return AgentMessage( diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/embedding.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/embedding.py index 
fb89457f..f867ff16 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/embedding.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/embedding.py @@ -13,7 +13,7 @@ from a2a.types import Message as A2AMessage from typing_extensions import override -from kagenti_adk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec +from kagenti_adk.a2a.extensions.base import DEFAULT_DEMAND_NAME, BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec from kagenti_adk.util.pydantic import REVEAL_SECRETS, SecureBaseModel, redact_str __all__ = [ @@ -40,6 +40,7 @@ ] + class EmbeddingFulfillment(SecureBaseModel): identifier: str | None = None """ @@ -88,25 +89,32 @@ class EmbeddingServiceExtensionParams(pydantic.BaseModel): """Model requests that the agent requires to be provided by the client.""" -class EmbeddingServiceExtensionSpec(BaseExtensionSpec[EmbeddingServiceExtensionParams]): +class EmbeddingServiceExtensionMetadata(pydantic.BaseModel): + embedding_fulfillments: dict[str, EmbeddingFulfillment] = {} + """Provided models corresponding to the model requests.""" + + +class EmbeddingServiceExtensionSpec( + BaseExtensionSpec[EmbeddingServiceExtensionParams, EmbeddingServiceExtensionMetadata] +): URI: str = "https://a2a-extensions.adk.kagenti.dev/services/embedding/v1" @classmethod def single_demand( - cls, name: str | None = None, description: str | None = None, suggested: tuple[str, ...] = () + cls, + name: str = DEFAULT_DEMAND_NAME, + description: str | None = None, + suggested: tuple[str, ...] 
= (), + default: EmbeddingFulfillment | None = None, ) -> Self: return cls( params=EmbeddingServiceExtensionParams( - embedding_demands={name or "default": EmbeddingDemand(description=description, suggested=suggested)} - ) + embedding_demands={name: EmbeddingDemand(description=description, suggested=suggested)} + ), + default=EmbeddingServiceExtensionMetadata(embedding_fulfillments={name: default}) if default else None, ) -class EmbeddingServiceExtensionMetadata(pydantic.BaseModel): - embedding_fulfillments: dict[str, EmbeddingFulfillment] = {} - """Provided models corresponding to the model requests.""" - - class EmbeddingServiceExtensionServer( BaseExtensionServer[EmbeddingServiceExtensionSpec, EmbeddingServiceExtensionMetadata] ): diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/form.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/form.py index 90569087..93735b5b 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/form.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/form.py @@ -31,7 +31,7 @@ class FormServiceExtensionParams(BaseModel): form_demands: FormDemands -class FormServiceExtensionSpec(BaseExtensionSpec[FormServiceExtensionParams]): +class FormServiceExtensionSpec(BaseExtensionSpec[FormServiceExtensionParams, FormServiceExtensionMetadata]): URI: str = "https://a2a-extensions.adk.kagenti.dev/services/form/v1" @classmethod diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/llm.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/llm.py index 2fc16a35..c51ae18a 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/llm.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/llm.py @@ -14,7 +14,7 @@ from a2a.types import Message as A2AMessage from typing_extensions import override -from kagenti_adk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec +from kagenti_adk.a2a.extensions.base import DEFAULT_DEMAND_NAME, BaseExtensionClient, 
BaseExtensionServer, BaseExtensionSpec from kagenti_adk.util.pydantic import REVEAL_SECRETS, SecureBaseModel, redact_str __all__ = [ @@ -41,6 +41,7 @@ ] + class LLMFulfillment(SecureBaseModel): identifier: str | None = None """ @@ -89,25 +90,30 @@ class LLMServiceExtensionParams(pydantic.BaseModel): """Model requests that the agent requires to be provided by the client.""" -class LLMServiceExtensionSpec(BaseExtensionSpec[LLMServiceExtensionParams]): +class LLMServiceExtensionMetadata(pydantic.BaseModel): + llm_fulfillments: dict[str, LLMFulfillment] = {} + """Provided models corresponding to the model requests.""" + + +class LLMServiceExtensionSpec(BaseExtensionSpec[LLMServiceExtensionParams, LLMServiceExtensionMetadata]): URI: str = "https://a2a-extensions.adk.kagenti.dev/services/llm/v1" @classmethod def single_demand( - cls, name: str | None = None, description: str | None = None, suggested: tuple[str, ...] = () + cls, + name: str = DEFAULT_DEMAND_NAME, + description: str | None = None, + suggested: tuple[str, ...] 
= (), + default: LLMFulfillment | None = None, ) -> Self: return cls( params=LLMServiceExtensionParams( - llm_demands={name or "default": LLMDemand(description=description, suggested=suggested)} - ) + llm_demands={name: LLMDemand(description=description, suggested=suggested)} + ), + default=LLMServiceExtensionMetadata(llm_fulfillments={name: default}) if default else None, ) -class LLMServiceExtensionMetadata(pydantic.BaseModel): - llm_fulfillments: dict[str, LLMFulfillment] = {} - """Provided models corresponding to the model requests.""" - - class LLMServiceExtensionServer(BaseExtensionServer[LLMServiceExtensionSpec, LLMServiceExtensionMetadata]): @override def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext): diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/mcp.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/mcp.py index 2559e658..19212611 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/mcp.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/mcp.py @@ -18,7 +18,7 @@ from typing_extensions import override from kagenti_adk.a2a.extensions.auth.oauth.oauth import OAuthExtensionServer -from kagenti_adk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec +from kagenti_adk.a2a.extensions.base import DEFAULT_DEMAND_NAME, BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec from kagenti_adk.a2a.extensions.services.platform import PlatformApiExtensionServer from kagenti_adk.platform.client import get_platform_client from kagenti_adk.util.logging import logger @@ -55,7 +55,6 @@ _TRANSPORT_TYPES = Literal["streamable_http", "stdio"] -_DEFAULT_DEMAND_NAME = "default" _DEFAULT_ALLOWED_TRANSPORTS: list[_TRANSPORT_TYPES] = ["streamable_http"] @@ -115,16 +114,22 @@ class MCPServiceExtensionParams(pydantic.BaseModel): """Server requests that the agent requires to be provided by the client.""" -class 
MCPServiceExtensionSpec(BaseExtensionSpec[MCPServiceExtensionParams]): +class MCPServiceExtensionMetadata(pydantic.BaseModel): + mcp_fulfillments: dict[str, MCPFulfillment] = {} + """Provided servers corresponding to the server requests.""" + + +class MCPServiceExtensionSpec(BaseExtensionSpec[MCPServiceExtensionParams, MCPServiceExtensionMetadata]): URI: str = "https://a2a-extensions.adk.kagenti.dev/services/mcp/v1" @classmethod def single_demand( cls, - name: str = _DEFAULT_DEMAND_NAME, + name: str = DEFAULT_DEMAND_NAME, description: str | None = None, suggested: tuple[str, ...] = (), allowed_transports: list[_TRANSPORT_TYPES] | None = None, + default: MCPFulfillment | None = None, ) -> Self: return cls( params=MCPServiceExtensionParams( @@ -135,15 +140,11 @@ def single_demand( allowed_transports=allowed_transports or _DEFAULT_ALLOWED_TRANSPORTS, ) } - ) + ), + default=MCPServiceExtensionMetadata(mcp_fulfillments={name: default}) if default else None, ) -class MCPServiceExtensionMetadata(pydantic.BaseModel): - mcp_fulfillments: dict[str, MCPFulfillment] = {} - """Provided servers corresponding to the server requests.""" - - class MCPServiceExtensionServer(BaseExtensionServer[MCPServiceExtensionSpec, MCPServiceExtensionMetadata]): @override def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext): @@ -173,7 +174,7 @@ def parse_client_metadata(self, message: A2AMessage) -> MCPServiceExtensionMetad return metadata @asynccontextmanager - async def create_client(self, demand: str = _DEFAULT_DEMAND_NAME): + async def create_client(self, demand: str = DEFAULT_DEMAND_NAME): fulfillment = self.data.mcp_fulfillments.get(demand) if self.data else None if not fulfillment: diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/platform.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/platform.py index e64ea62c..d4a2f450 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/services/platform.py +++ 
b/apps/adk-py/src/kagenti_adk/a2a/extensions/services/platform.py @@ -66,7 +66,7 @@ class PlatformApiExtensionParams(pydantic.BaseModel): auto_use: bool = True -class PlatformApiExtensionSpec(BaseExtensionSpec[PlatformApiExtensionParams]): +class PlatformApiExtensionSpec(BaseExtensionSpec[PlatformApiExtensionParams, PlatformApiExtensionMetadata]): URI: str = "https://a2a-extensions.adk.kagenti.dev/services/platform_api/v1" def __init__(self, params: PlatformApiExtensionParams | None = None) -> None: @@ -152,7 +152,9 @@ class _PlatformSelfRegistrationExtensionParams(pydantic.BaseModel): self_registration_id: str -class _PlatformSelfRegistrationExtensionSpec(BaseExtensionSpec[_PlatformSelfRegistrationExtensionParams]): +class _PlatformSelfRegistrationExtensionSpec( + BaseExtensionSpec[_PlatformSelfRegistrationExtensionParams, _PlatformSelfRegistrationExtension] +): URI: str = "https://a2a-extensions.adk.kagenti.dev/services/platform-self-registration/v1" diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/streaming.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/streaming.py new file mode 100644 index 00000000..2467f7a7 --- /dev/null +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/streaming.py @@ -0,0 +1,335 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + + +from __future__ import annotations + +from collections.abc import AsyncIterator +from enum import StrEnum +from types import NoneType +from typing import Any, cast + +from a2a.client.client import ClientEvent +from a2a.types import ( + AgentExtension, + Message, + TaskArtifactUpdateEvent, + TaskStatus, + TaskStatusUpdateEvent, +) +from google.protobuf.json_format import MessageToDict +from pydantic import BaseModel + +from kagenti_adk.a2a.extensions.base import ( + BaseExtensionClient, + BaseExtensionServer, + NoParamsBaseExtensionSpec, +) +from kagenti_adk.a2a.types import Metadata +from kagenti_adk.types import JsonPatch, JsonValue + + +class 
StreamingExtensionSpec(NoParamsBaseExtensionSpec[NoneType]): + URI = "https://a2a-extensions.agentstack.beeai.dev/ui/streaming/v1" + DESCRIPTION = "Enables fine-grained streaming of token chunks." + + +class StreamOperations(StrEnum): + MESSAGE_UPDATE = "message_update" + + +class StreamingExtensionServer(BaseExtensionServer[StreamingExtensionSpec, NoneType]): + """ + Adds streaming support to the A2A protocol through TaskStatusUpdateEvent.metadata object. + + Updates are emitted as metadata objects containing JSON Patch (RFC 6902) operations, + extended with `str_ins` from json-crdt-patch for efficient text streaming. + + Supported operations: + - replace: initialize the message draft (root replace with message_id and parts) + - add: adding parts to the message + - str_ins: streaming individual text chunks (=llm tokens) + + The stream is a sequential log of patches, that are applied to a final message: + --- + update: {..., "https://.../streaming": {"message_update": { "op": "replace", "path": "", "value": {"message_id": "...", "parts": [{"text": "Hello "}]} }, "message_id": "..."}} + update: {..., "https://.../streaming": {"message_update": { "op": "str_ins", "path": "/parts/0/text", "pos": 6, "value": "world" }, "message_id": "..."}} + update: {..., "https://.../streaming": {"message_update": { "op": "str_ins", "path": "/parts/0/text", "pos": 11, "value": "!" 
}, "message_id": "..."}} + """ + + def to_metadata(self, patches: JsonPatch, message_id: str | None = None) -> Metadata: + payload: dict[str, Any] = {StreamOperations.MESSAGE_UPDATE: patches} + if message_id is not None: + payload["message_id"] = message_id + return Metadata({self.spec.URI: cast(JsonValue, payload)}) + + +# --- Client-side delta types --- + + +class TextDelta(BaseModel): + """A text chunk appended to an existing text part.""" + + part_index: int + delta: str + + +class PartDelta(BaseModel): + """A new part was added to the message.""" + + part_index: int + part: dict[str, Any] + + +class MetadataDelta(BaseModel): + """Message metadata was added or updated.""" + + metadata: dict[str, Any] + + +class ArtifactDelta(BaseModel, arbitrary_types_allowed=True): + """An artifact update event.""" + + event: TaskArtifactUpdateEvent + + +class StateChange(BaseModel, arbitrary_types_allowed=True): + """A task state transition (WORKING, COMPLETED, INPUT_REQUIRED, etc.).""" + + state: int # TaskState enum value + message: Message | None = None + + +StreamDelta = TextDelta | PartDelta | MetadataDelta | ArtifactDelta | StateChange + + +class StreamingExtensionClient(BaseExtensionClient[StreamingExtensionSpec, NoneType]): + """Client-side streaming consumer. + + Wraps raw A2A ``ClientEvent`` streams into a unified delta-based API. + Works identically whether the server supports the streaming extension or not: + + - **With streaming**: patches are applied incrementally; full messages whose + ``message_id`` was already streamed are suppressed. + - **Without streaming**: full messages are decomposed into ``PartDelta`` / + ``MetadataDelta`` / ``StateChange`` events so the consumer code is the same. + + Usage:: + + streaming = StreamingExtensionClient(spec) + async for delta, task in streaming.stream(client.send_message(msg)): + match delta: + case TextDelta(delta=text): + print(text, end="", flush=True) + case PartDelta(part=part): + ... 
+ case StateChange(state=TaskState.TASK_STATE_COMPLETED): + print() + case ArtifactDelta(event=evt): + ... + """ + + def __init__(self, spec: StreamingExtensionSpec) -> None: + super().__init__(spec) + self._draft: dict[str, Any] = {} + self._message_id: str | None = None + self._streamed_messages: dict[str, int] = {} # message_id -> parts_count + + @property + def draft(self) -> dict[str, Any]: + """Current draft message built from applied patches.""" + return self._draft + + @property + def message_id(self) -> str | None: + """Current message_id being built from patches.""" + return self._message_id + + def to_agent_card_extensions(self, **kwargs) -> list[AgentExtension]: + return self.spec.to_agent_card_extensions(required=False) + + async def stream( + self, + events: AsyncIterator[ClientEvent], + ) -> AsyncIterator[tuple[StreamDelta, Any | None]]: + """Consume a raw A2A event stream and yield ``(delta, task)`` pairs. + + The method handles reconciliation automatically: messages that were + already streamed via patches are suppressed, and merged messages + (draft + explicit) only emit the new parts beyond the streamed prefix. 
+ """ + async for response, task in events: + if response.HasField("artifact_update"): + yield ArtifactDelta(event=response.artifact_update), task + continue + + if not response.HasField("status_update"): + continue + + update: TaskStatusUpdateEvent = response.status_update + patch_data = self._extract_patches(update) + + if patch_data is not None: + # Streaming mode: apply patch and emit deltas + for delta in self._apply_and_emit(patch_data): + yield delta, task + continue + + # Non-streaming event (full message or state change) + status: TaskStatus = update.status + message: Message | None = status.message if status.HasField("message") else None + + if message and message.message_id and message.message_id in self._streamed_messages: + # This message was already streamed via patches + streamed_count = self._streamed_messages[message.message_id] + parts_list = list(message.parts) + + if len(parts_list) > streamed_count: + # Merged message: emit only the new parts beyond the streamed prefix + for i, part in enumerate(parts_list[streamed_count:], start=streamed_count): + yield PartDelta(part_index=i, part=MessageToDict(part)), task + + # Emit state change with the full message for reference + if MessageToDict(status): + yield StateChange(state=status.state, message=message), task + + # Clean up tracking for this message + del self._streamed_messages[message.message_id] + self._draft = {} + self._message_id = None + continue + + # Non-streamed message: decompose into deltas + if message and message.message_id: + for i, part in enumerate(message.parts): + yield PartDelta(part_index=i, part=MessageToDict(part)), task + meta = MessageToDict(message.metadata) + if meta: + yield MetadataDelta(metadata=meta), task + + if MessageToDict(status): + yield StateChange(state=status.state, message=message), task + + def text_delta(self, event: TaskStatusUpdateEvent) -> str | None: + """Extract text delta from streaming patch in event metadata. 
+ + Returns the text content of the first text-producing patch, or ``None`` + if the event does not contain a text-producing streaming patch. + """ + patches = self._extract_patches(event) + if patches is None: + return None + for patch in patches: + if (text := self._patch_text_delta(patch)) is not None: + return text + return None + + def apply_patch(self, event: TaskStatusUpdateEvent) -> dict[str, Any] | None: + """Apply streaming patches to internal draft. Returns current draft state, or ``None`` if no patches.""" + patches = self._extract_patches(event) + if patches is None: + return None + self._apply_patches_to_draft(patches) + return self._draft + + def _extract_patches(self, event: TaskStatusUpdateEvent) -> list[dict[str, Any]] | None: + """Extract the streaming patch list from event metadata, if present.""" + if not event.HasField("metadata"): + return None + meta = MessageToDict(event.metadata) + if not meta or self.spec.URI not in meta: + return None + ext_data = meta[self.spec.URI] + if not isinstance(ext_data, dict): + return None + patches = ext_data.get(StreamOperations.MESSAGE_UPDATE) + if not isinstance(patches, list): + return None + + # Track message_id from extension metadata + msg_id = ext_data.get("message_id") + if msg_id and isinstance(msg_id, str): + self._message_id = msg_id + + return patches + + def _apply_patches_to_draft(self, patches: list[dict[str, Any]]) -> None: + """Apply a list of patch operations to the internal draft.""" + from kagenti_adk.server.jsonpatch_ext import ExtendedJsonPatch + + self._draft = ExtendedJsonPatch(patches).apply(self._draft) + # Update tracking + if self._message_id: + parts = self._draft.get("parts", []) + self._streamed_messages[self._message_id] = len(parts) + + def _apply_and_emit(self, patches: list[dict[str, Any]]) -> list[StreamDelta]: + """Apply patches and return the corresponding deltas.""" + from kagenti_adk.server.jsonpatch_ext import ExtendedJsonPatch + + deltas: list[StreamDelta] = [] + 
metadata_ops: list[dict[str, Any]] = [] + parts_before = len(self._draft.get("parts", [])) + + self._apply_patches_to_draft(patches) + + add_parts_seen = 0 + for patch in patches: + op = patch.get("op") + path = patch.get("path", "") + value = patch.get("value") + + if op == "str_ins": + segments = path.split("/") + if len(segments) >= 3 and segments[1] == "parts": + part_index = int(segments[2]) + deltas.append(TextDelta(part_index=part_index, delta=patch.get("value", ""))) + + elif op == "replace" and path == "": + if isinstance(value, dict): + for i, part in enumerate(value.get("parts", [])): + deltas.append(PartDelta(part_index=i, part=part)) + if meta := value.get("metadata"): + metadata_ops.append({"op": "replace", "path": "", "value": meta}) + + elif op == "add" and path == "/parts/-": + part_index = parts_before + add_parts_seen + add_parts_seen += 1 + if isinstance(value, dict): + deltas.append(PartDelta(part_index=part_index, part=value)) + + elif path.startswith("/metadata"): + # Strip /metadata prefix and collect for incremental application + metadata_ops.append({**patch, "path": path[len("/metadata"):]}) + + if metadata_ops: + incremental = ExtendedJsonPatch(metadata_ops).apply({}) + if incremental: + deltas.append(MetadataDelta(metadata=incremental)) + + return deltas + + @staticmethod + def _patch_text_delta(patch: dict[str, Any]) -> str | None: + """Extract a text delta from a single patch operation.""" + op = patch.get("op") + value = patch.get("value") + path = patch.get("path", "") + + if op == "str_ins": + return value if isinstance(value, str) else None + + if op == "replace" and path == "": + # Root replace -- extract text from first part if any + if isinstance(value, dict): + parts = value.get("parts", []) + if parts and isinstance(parts[0], dict) and "text" in parts[0]: + return parts[0]["text"] + return None + + if op == "add" and "/parts/" in path: + if isinstance(value, dict) and "text" in value: + return value["text"] + return None + + 
return None diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/tools/call.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/tools/call.py index d0bfdc44..7364ae53 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/tools/call.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/tools/call.py @@ -81,14 +81,14 @@ class ToolCallExtensionParams(BaseModel): pass -class ToolCallExtensionSpec(BaseExtensionSpec[ToolCallExtensionParams]): - URI: str = "https://a2a-extensions.adk.kagenti.dev/tools/call/v1" - - class ToolCallExtensionMetadata(BaseModel): pass +class ToolCallExtensionSpec(BaseExtensionSpec[ToolCallExtensionParams, ToolCallExtensionMetadata]): + URI: str = "https://a2a-extensions.adk.kagenti.dev/tools/call/v1" + + class ToolCallExtensionServer(BaseExtensionServer[ToolCallExtensionSpec, ToolCallExtensionMetadata]): def create_request_message(self, *, request: ToolCallRequest): return AgentMessage( diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/__init__.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/__init__.py index 85e56c35..907bfe67 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/__init__.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/__init__.py @@ -2,11 +2,114 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +from .agent_detail import ( + AgentDetail, + AgentDetailContributor, + AgentDetailExtensionClient, + AgentDetailExtensionServer, + AgentDetailExtensionSpec, + AgentDetailTool, + EnvVar, +) +from .canvas import ( + CanvasEditRequest, + CanvasEditRequestMetadata, + CanvasExtensionServer, + CanvasExtensionSpec, +) +from .citation import ( + Citation, + CitationExtensionClient, + CitationExtensionServer, + CitationExtensionSpec, + CitationMetadata, +) +from .error import ( + DEFAULT_ERROR_EXTENSION, + Error, + ErrorContext, + ErrorExtensionClient, + ErrorExtensionParams, + ErrorExtensionServer, + ErrorExtensionSpec, + ErrorGroup, + ErrorMetadata, + get_error_extension_context, + 
use_error_extension_context, +) +from .form_request import ( + FormRequestExtensionClient, + FormRequestExtensionServer, + FormRequestExtensionSpec, +) +from .settings import ( + AgentRunSettings, + CheckboxField, + CheckboxFieldValue, + CheckboxGroupField, + CheckboxGroupFieldValue, + OptionItem, + SettingsExtensionClient, + SettingsExtensionServer, + SettingsExtensionSpec, + SettingsFieldValue, + SettingsRender, + SingleSelectField, + SingleSelectFieldValue, +) +from .trajectory import ( + Trajectory, + TrajectoryExtensionClient, + TrajectoryExtensionServer, + TrajectoryExtensionSpec, +) -from .agent_detail import * -from .canvas import * -from .citation import * -from .error import * -from .form_request import * -from .settings import * -from .trajectory import * +__all__ = [ + "AgentDetail", + "AgentDetailContributor", + "AgentDetailExtensionClient", + "AgentDetailExtensionServer", + "AgentDetailExtensionSpec", + "AgentDetailTool", + "AgentRunSettings", + "CanvasEditRequest", + "CanvasEditRequestMetadata", + "CanvasExtensionServer", + "CanvasExtensionSpec", + "CheckboxField", + "CheckboxFieldValue", + "CheckboxGroupField", + "CheckboxGroupFieldValue", + "Citation", + "CitationExtensionClient", + "CitationExtensionServer", + "CitationExtensionSpec", + "CitationMetadata", + "DEFAULT_ERROR_EXTENSION", + "EnvVar", + "Error", + "ErrorContext", + "ErrorExtensionClient", + "ErrorExtensionParams", + "ErrorExtensionServer", + "ErrorExtensionSpec", + "ErrorGroup", + "ErrorMetadata", + "FormRequestExtensionClient", + "FormRequestExtensionServer", + "FormRequestExtensionSpec", + "OptionItem", + "SettingsExtensionClient", + "SettingsExtensionServer", + "SettingsExtensionSpec", + "SettingsFieldValue", + "SettingsRender", + "SingleSelectField", + "SingleSelectFieldValue", + "Trajectory", + "TrajectoryExtensionClient", + "TrajectoryExtensionServer", + "TrajectoryExtensionSpec", + "get_error_extension_context", + "use_error_extension_context", +] diff --git 
a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/agent_detail.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/agent_detail.py index f867f693..2170a910 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/agent_detail.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/agent_detail.py @@ -44,7 +44,7 @@ class AgentDetail(pydantic.BaseModel, extra="allow"): variables: list[EnvVar] | None = None -class AgentDetailExtensionSpec(BaseExtensionSpec[AgentDetail]): +class AgentDetailExtensionSpec(BaseExtensionSpec[AgentDetail, NoneType]): URI: str = "https://a2a-extensions.adk.kagenti.dev/ui/agent-detail/v1" diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/canvas.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/canvas.py index 2d274743..a9371387 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/canvas.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/canvas.py @@ -57,7 +57,7 @@ def parse_artifact(cls, v): return v -class CanvasExtensionSpec(NoParamsBaseExtensionSpec): +class CanvasExtensionSpec(NoParamsBaseExtensionSpec[CanvasEditRequestMetadata]): URI: str = "https://a2a-extensions.adk.kagenti.dev/ui/canvas/v1" diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/citation.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/citation.py index 131d41dd..8a9d7eb0 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/citation.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/citation.py @@ -54,13 +54,13 @@ class CitationMetadata(pydantic.BaseModel): citations: list[Citation] = pydantic.Field(default_factory=list) -class CitationExtensionSpec(NoParamsBaseExtensionSpec): +class CitationExtensionSpec(NoParamsBaseExtensionSpec[NoneType]): URI: str = "https://a2a-extensions.adk.kagenti.dev/ui/citation/v1" class CitationExtensionServer(BaseExtensionServer[CitationExtensionSpec, NoneType]): def citation_metadata(self, *, citations: list[Citation]) -> Metadata: - return Metadata({self.spec.URI: 
CitationMetadata(citations=citations).model_dump(mode="json")}) + return Metadata({self.spec.URI: [c.model_dump(mode="json") for c in citations]}) def message( self, @@ -76,4 +76,4 @@ def message( ) -class CitationExtensionClient(BaseExtensionClient[CitationExtensionSpec, CitationMetadata]): ... +class CitationExtensionClient(BaseExtensionClient[CitationExtensionSpec, list[Citation]]): ... diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/error.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/error.py index fbfbfdfb..3284ba26 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/error.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/error.py @@ -85,7 +85,7 @@ class ErrorExtensionParams(pydantic.BaseModel): include_stacktrace: bool = False -class ErrorExtensionSpec(BaseExtensionSpec[ErrorExtensionParams]): +class ErrorExtensionSpec(BaseExtensionSpec[ErrorExtensionParams, NoneType]): URI: str = "https://a2a-extensions.adk.kagenti.dev/ui/error/v1" diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/form_request.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/form_request.py index 260a7182..b175cf71 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/form_request.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/form_request.py @@ -37,7 +37,7 @@ T = TypeVar("T") -class FormRequestExtensionSpec(NoParamsBaseExtensionSpec): +class FormRequestExtensionSpec(NoParamsBaseExtensionSpec[FormResponse]): URI: str = "https://a2a-extensions.adk.kagenti.dev/ui/form_request/v1" diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/settings.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/settings.py index 9de81a45..538fd066 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/settings.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/settings.py @@ -72,7 +72,7 @@ class AgentRunSettings(BaseModel): @deprecated("Use FormServiceExtensionSpec.demand_settings() instead") -class 
SettingsExtensionSpec(BaseExtensionSpec[SettingsRender | None]): +class SettingsExtensionSpec(BaseExtensionSpec[SettingsRender | None, AgentRunSettings]): URI: str = "https://a2a-extensions.adk.kagenti.dev/ui/settings/v1" diff --git a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/trajectory.py b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/trajectory.py index ac4d6eaf..2a06de95 100644 --- a/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/trajectory.py +++ b/apps/adk-py/src/kagenti_adk/a2a/extensions/ui/trajectory.py @@ -42,7 +42,7 @@ class Trajectory(pydantic.BaseModel): group_id: str | None = None -class TrajectoryExtensionSpec(NoParamsBaseExtensionSpec): +class TrajectoryExtensionSpec(NoParamsBaseExtensionSpec[NoneType]): URI: str = "https://a2a-extensions.adk.kagenti.dev/ui/trajectory/v1" @@ -51,7 +51,11 @@ def trajectory_metadata( self, *, title: str | None = None, content: str | None = None, group_id: str | None = None ) -> Metadata: return Metadata( - {self.spec.URI: Trajectory(title=title, content=content, group_id=group_id).model_dump(mode="json")} + { + self.spec.URI: [ + Trajectory(title=title, content=content, group_id=group_id).model_dump(mode="json"), + ] + } ) def message( diff --git a/apps/adk-py/src/kagenti_adk/server/accumulator.py b/apps/adk-py/src/kagenti_adk/server/accumulator.py new file mode 100644 index 00000000..349afdb5 --- /dev/null +++ b/apps/adk-py/src/kagenti_adk/server/accumulator.py @@ -0,0 +1,219 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import uuid +from dataclasses import dataclass +from typing import Self + +from a2a.types import ( + Message, + Part, + Role, + TaskStatus, + TaskStatusUpdateEvent, +) +from pydantic import BaseModel, Field + +from kagenti_adk.a2a.types import Metadata, RunYield +from kagenti_adk.types import JsonPatch, JsonPatchOp + + +class TextPartContext(BaseModel, arbitrary_types_allowed=True): + chunks: list[str] = 
Field(default_factory=list) + message_context: MessageContext + part_index: int + pos: int = 0 + + def add_chunk(self, chunk: str) -> JsonPatch: + from google.protobuf.json_format import MessageToDict + + self.chunks.append(chunk) + part_dict = MessageToDict(Part(text=chunk)) + if self.pos == 0: + if not self.message_context.initialized: + self.message_context.initialized = True + msg_id = str(self.message_context.message_id) + patch: JsonPatchOp = { + "op": "replace", + "path": "", + "value": {"message_id": msg_id, "parts": [part_dict]}, + } + else: + patch = {"op": "add", "path": "/parts/-", "value": part_dict} + else: + patch = {"op": "str_ins", "pos": self.pos, "path": f"/parts/{self.part_index}/text", "value": chunk} + self.pos += len(chunk) + return [patch] + + def build(self) -> Part: + return Part(text="".join(self.chunks)) + + +class MessageContext(BaseModel, arbitrary_types_allowed=True): + parts: list[Part] = Field(default_factory=list) + metadata: Metadata | None = None + initialized: bool = False + message_id: uuid.UUID = Field(default_factory=uuid.uuid4) + + def add_metadata(self, metadata: Metadata) -> JsonPatch: + from kagenti_adk.server.jsonpatch_ext import make_patch + from kagenti_adk.server.utils import merge_metadata + + if self.metadata is None: + self.metadata = Metadata(metadata) + if not self.initialized: + self.initialized = True + return [ + { + "op": "replace", + "path": "", + "value": {"message_id": str(self.message_id), "parts": [], "metadata": dict(self.metadata)}, + } + ] + return [{"op": "add", "path": "/metadata", "value": dict(self.metadata)}] + old_metadata = dict(self.metadata) + self.metadata = merge_metadata(self.metadata, metadata) + new_metadata = dict(self.metadata) + + ops: JsonPatch = [] + for op in make_patch(old_metadata, new_metadata).patch: + rewritten: JsonPatchOp = {"op": op["op"], "path": f"/metadata{op['path']}"} + if "value" in op: + rewritten["value"] = op["value"] + if "pos" in op: + rewritten["pos"] = 
op["pos"] + ops.append(rewritten) + return ops + + def add_part(self, part: Part) -> JsonPatch: + from google.protobuf.json_format import MessageToDict + + self.parts.append(part) + part_dict = MessageToDict(part) + if not self.initialized: + self.initialized = True + return [{"op": "replace", "path": "", "value": {"message_id": str(self.message_id), "parts": [part_dict]}}] + return [{"op": "add", "path": "/parts/-", "value": part_dict}] + + def build(self) -> Message: + m = Message(message_id=str(self.message_id), role=Role.ROLE_AGENT) + m.parts.extend(self.parts) + if self.metadata: + for k, v in self.metadata.items(): + m.metadata[k] = v + return m + + +@dataclass +class ProcessResult: + """Result of processing a yield through the accumulator.""" + + accumulated: bool = False + """True if the value was consumed by the accumulator (str, Part, Metadata).""" + + draft: Message | None = None + """Flushed message from accumulated state, when a non-accumulating yield triggers a flush.""" + + patch: JsonPatch | None = None + """Streaming patches to send as a partial update (JSON Patch operations).""" + + message_id: str | None = None + """The message_id of the current accumulation cycle, for client-side correlation.""" + + +class MessageAccumulator: + """Manages the streaming accumulation state machine. + + Accumulates string chunks, Parts, and Metadata into messages, + flushing when non-accumulating yields (Message, TaskStatus, etc.) arrive. 
+ + The state machine has 3 levels: + - Base level (Self): no accumulation in progress + - MessageContext: accumulating parts and metadata into a message + - TextPartContext: accumulating string chunks into a single text Part + """ + + def __init__(self) -> None: + self._active: Self | MessageContext | TextPartContext = self + + @property + def active_context(self) -> MessageAccumulator | MessageContext | TextPartContext: + return self._active + + def process(self, value: RunYield) -> ProcessResult: + """Process a yield value through the state machine. + + Returns a ProcessResult describing what happened: + - accumulated=True: the value was consumed. patch may contain a streaming update. + - accumulated=False: the value is a "control" yield (Message, TaskStatus, etc.) + that the caller should handle. draft may contain a flushed accumulated message. + """ + match self._active: + case MessageAccumulator(): + return self._process_at_base_level(value) + case MessageContext() as ctx: + return self._process_at_message_level(ctx, value) + case TextPartContext() as ctx: + return self._process_at_text_part_level(ctx, value) + + def flush(self) -> Message | None: + """Flush any accumulated state into a message. 
Resets to base level.""" + match self._active: + case TextPartContext() as ctx: + ctx.message_context.add_part(ctx.build()) + msg = ctx.message_context.build() + case MessageContext() as ctx: + msg = ctx.build() + case _: + return None + self._active = self + return msg + + def _process_at_base_level(self, value: RunYield) -> ProcessResult: + match value: + case Message() | TaskStatus() | TaskStatusUpdateEvent(): + return ProcessResult(accumulated=False) + case _: + self._active = MessageContext() + return self._process_at_message_level(self._active, value) + + def _process_at_message_level(self, ctx: MessageContext, value: RunYield) -> ProcessResult: + msg_id = str(ctx.message_id) + match value: + case Part() as part: + patch = ctx.add_part(part) + return ProcessResult(accumulated=True, patch=patch, message_id=msg_id) + case Metadata() as metadata: + patch = ctx.add_metadata(metadata) + return ProcessResult(accumulated=True, patch=patch, message_id=msg_id) + case dict() as data: + patch = ctx.add_part(self._dict_to_part(data)) + return ProcessResult(accumulated=True, patch=patch, message_id=msg_id) + case str(): + self._active = TextPartContext(message_context=ctx, part_index=len(ctx.parts)) + return self._process_at_text_part_level(self._active, value) + case _: + draft = ctx.build() + self._active = self + return ProcessResult(accumulated=False, draft=draft) + + @staticmethod + def _dict_to_part(data: dict) -> Part: + from google.protobuf.struct_pb2 import Struct, Value + + s = Struct() + s.update(data) + return Part(data=Value(struct_value=s)) + + def _process_at_text_part_level(self, ctx: TextPartContext, value: RunYield) -> ProcessResult: + msg_id = str(ctx.message_context.message_id) + match value: + case str(text): + patch = ctx.add_chunk(text) + return ProcessResult(accumulated=True, patch=patch, message_id=msg_id) + case _: + ctx.message_context.add_part(ctx.build()) + self._active = ctx.message_context + return 
self._process_at_message_level(ctx.message_context, value) diff --git a/apps/adk-py/src/kagenti_adk/server/agent.py b/apps/adk-py/src/kagenti_adk/server/agent.py index 94a03f9d..535066e2 100644 --- a/apps/adk-py/src/kagenti_adk/server/agent.py +++ b/apps/adk-py/src/kagenti_adk/server/agent.py @@ -35,7 +35,8 @@ from google.protobuf import message as _message from typing_extensions import override -from kagenti_adk.a2a.extensions import AgentDetailExtensionSpec, BaseExtensionServer +from kagenti_adk.a2a.extensions import BaseExtensionServer +from kagenti_adk.a2a.extensions.streaming import StreamingExtensionServer from kagenti_adk.a2a.extensions.ui.agent_detail import ( AgentDetail, AgentDetailExtensionSpec, @@ -45,12 +46,14 @@ get_error_extension_context, ) from kagenti_adk.a2a.types import Metadata, RunYield, RunYieldResume, validate_message +from kagenti_adk.server.accumulator import MessageAccumulator from kagenti_adk.server.constants import _DEFAULT_AGENT_INTERFACE, _DEFAULT_AGENT_SKILL, DEFAULT_IMPLICIT_EXTENSIONS from kagenti_adk.server.context import RunContext from kagenti_adk.server.dependencies import Dependency, Depends, extract_dependencies +from kagenti_adk.server.exceptions import InvalidYieldError from kagenti_adk.server.store.context_store import ContextStore -from kagenti_adk.server.utils import cancel_task -from kagenti_adk.types import A2ASecurity +from kagenti_adk.server.utils import cancel_task, merge_messages +from kagenti_adk.types import A2ASecurity, JsonPatch from kagenti_adk.util.logging import logger AgentFunction: TypeAlias = Callable[[], AsyncGenerator[RunYield, RunYieldResume]] @@ -285,15 +288,17 @@ def decorator(fn: OriginalFnType) -> Agent: if inspect.isasyncgenfunction(fn): async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None: + gen: AsyncGenerator[RunYield, RunYieldResume] = fn(*args, **kwargs) try: - gen: AsyncGenerator[RunYield, RunYieldResume] = fn(*args, **kwargs) value: RunYieldResume = None while True: - value = 
await _ctx.yield_async(await gen.asend(value)) + result = await gen.asend(value) + try: + value = await _ctx.yield_async(result) + except Exception as e: + value = await _ctx.yield_async(await gen.athrow(e)) except StopAsyncIteration: pass - except Exception as e: - await _ctx.yield_async(e) finally: _ctx.shutdown() @@ -301,24 +306,26 @@ async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None: async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None: try: - await _ctx.yield_async(await fn(*args, **kwargs)) - except Exception as e: - await _ctx.yield_async(e) + result = await fn(*args, **kwargs) + if result is not None: + await _ctx.yield_async(result) finally: _ctx.shutdown() elif inspect.isgeneratorfunction(fn): def _execute_fn_sync(_ctx: RunContext, *args, **kwargs) -> None: + gen: Generator[RunYield, RunYieldResume] = fn(*args, **kwargs) try: - gen: Generator[RunYield, RunYieldResume] = fn(*args, **kwargs) - value = None + value: RunYieldResume = None while True: - value = _ctx.yield_sync(gen.send(value)) + result = gen.send(value) + try: + value = _ctx.yield_sync(result) + except Exception as e: + value = _ctx.yield_sync(gen.throw(e)) except StopIteration: pass - except Exception as e: - _ctx.yield_sync(e) finally: _ctx.shutdown() @@ -329,9 +336,9 @@ async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None: def _execute_fn_sync(_ctx: RunContext, *args, **kwargs) -> None: try: - _ctx.yield_sync(fn(*args, **kwargs)) - except Exception as e: - _ctx.yield_sync(e) + result = fn(*args, **kwargs) + if result is not None: + _ctx.yield_sync(result) finally: _ctx.shutdown() @@ -362,6 +369,7 @@ def __init__(self, agent: Agent, context_store: ContextStore, on_finish: Callabl self._on_finish: Callable[[], None] | None = on_finish self._working: bool = False self._dependency_container: ActiveDependenciesContainer | None = None + self._accumulator: MessageAccumulator = MessageAccumulator() @property def run_context(self) -> RunContext: @@ -440,16 
+448,125 @@ async def cancel(self, request_context: RequestContext, event_queue: EventQueue) finally: await cancel_task(self._task) - def _with_context(self, message: Message | None = None) -> Message | None: - if message: - message.context_id = self.task_updater.context_id - message.task_id = self.task_updater.task_id - return message + def _prepare_message(self, message: Message | None = None, msg_draft: Message | None = None) -> Message | None: + for msg in (message, msg_draft): + if msg: + msg.context_id = self.task_updater.context_id + msg.task_id = self.task_updater.task_id + msgs = [m for m in (msg_draft, message) if m] + return merge_messages(*msgs) if msgs else None + + async def _send_partial_update(self, patches: JsonPatch, message_id: str | None = None): + if not (ext := StreamingExtensionServer.current()): + return + await self.task_updater.update_status( + state=TaskState.TASK_STATE_WORKING, metadata=ext.to_metadata(patches, message_id=message_id) + ) - async def _run_agent_function(self, initial_message: Message) -> None: + async def _handle_message_yield(self, yielded_value: RunYield) -> RunYieldResume: + result = self._accumulator.process(yielded_value) + if result.accumulated: + if result.patch: + await self._send_partial_update(result.patch, message_id=result.message_id) + return None + return await self._dispatch_control_yield(yielded_value, result.draft) + + async def _dispatch_control_yield( + self, yielded_value: RunYield, draft: Message | None = None + ) -> RunYieldResume: + match yielded_value: + case Message() as message: + await self.task_updater.update_status( + TaskState.TASK_STATE_WORKING, + message=self._prepare_message(message, draft), + ) + case TaskStatus( + state=(TaskState.TASK_STATE_AUTH_REQUIRED | TaskState.TASK_STATE_INPUT_REQUIRED) as state, + message=message, + ): + await self.task_updater.update_status( + state=state, + message=self._prepare_message(message, draft), + ) + self._working = False + resume_value = await 
self.resume_queue.get() + self.resume_queue.task_done() + return resume_value + case TaskStatus(state=state, message=message): + await self.task_updater.update_status( + state=state, + message=self._prepare_message(message, draft), + ) + case TaskStatusUpdateEvent( + status=TaskStatus(state=state, message=message), + metadata=metadata, + ): + await self.task_updater.update_status( + state=state, + message=self._prepare_message(message, draft), + metadata=dict(metadata), + ) + + async def _agent_loop(self, task: asyncio.Task): yield_queue = self.run_context._yield_queue yield_resume_queue = self.run_context._yield_resume_queue + resume_value: RunYieldResume | Exception = None + opened_artifacts: set[str] = set() + + while not task.done() or yield_queue.async_q.qsize() > 0: + yielded_value = await yield_queue.async_q.get() + resume_value = None + self.last_invocation = datetime.now() + + if isinstance(yielded_value, _message.Message): + validate_message(yielded_value) + + try: + match yielded_value: + case Artifact(parts=parts, artifact_id=artifact_id, name=name, metadata=metadata): + last_chunk = True + if "_last_chunk" in metadata: + last_chunk = bool(metadata["_last_chunk"]) + del metadata["_last_chunk"] + append = artifact_id in opened_artifacts + if not last_chunk: + opened_artifacts.add(artifact_id) + elif artifact_id in opened_artifacts: + opened_artifacts.remove(artifact_id) + + await self.task_updater.add_artifact( + parts=list(parts), + artifact_id=artifact_id, + name=name, + metadata=dict(metadata), + last_chunk=last_chunk, + append=append, + ) + + case TaskArtifactUpdateEvent( + artifact=Artifact(artifact_id=artifact_id, name=name, metadata=metadata, parts=parts), + append=append, + last_chunk=last_chunk, + ): + await self.task_updater.add_artifact( + parts=list(parts), + artifact_id=artifact_id, + name=name, + metadata=dict(metadata), + append=append, + last_chunk=last_chunk, + ) + case Part() | dict() | Metadata() | str() | TaskStatus() | 
TaskStatusUpdateEvent() | Message(): + resume_value = await self._handle_message_yield(yielded_value) + case _: + raise InvalidYieldError(yielded_value) + except Exception as e: + resume_value = e + await yield_resume_queue.async_q.put(resume_value) + + async def _run_agent_function(self, initial_message: Message) -> None: + task: asyncio.Task | None = None try: async with self._agent.dependency_container( initial_message, self.run_context, self.request_context @@ -459,118 +576,22 @@ async def _run_agent_function(self, initial_message: Message) -> None: self._agent.execute_fn(self.run_context, **dependency_container.user_dependency_args) ) try: - resume_value: RunYieldResume = None - opened_artifacts: set[str] = set() - while not task.done() or yield_queue.async_q.qsize() > 0: - yielded_value = await yield_queue.async_q.get() - - if isinstance(yielded_value, _message.Message): - validate_message(yielded_value) - - self.last_invocation = datetime.now() - - match yielded_value: - case str(text): - await self.task_updater.update_status( - TaskState.TASK_STATE_WORKING, - message=self.task_updater.new_agent_message(parts=[Part(text=text)]), - ) - case Part() as part: - await self.task_updater.update_status( - TaskState.TASK_STATE_WORKING, - message=self.task_updater.new_agent_message(parts=[part]), - ) - case Message() as message: - await self.task_updater.update_status( - TaskState.TASK_STATE_WORKING, message=self._with_context(message) - ) - case Artifact(parts=parts, artifact_id=artifact_id, name=name, metadata=metadata): - last_chunk = True - if "_last_chunk" in metadata: - last_chunk = bool(metadata["_last_chunk"]) - del metadata["_last_chunk"] - append = artifact_id in opened_artifacts - if not last_chunk: - opened_artifacts.add(artifact_id) - elif artifact_id in opened_artifacts: - opened_artifacts.remove(artifact_id) - - await self.task_updater.add_artifact( - parts=list(parts), - artifact_id=artifact_id, - name=name, - metadata=dict(metadata), - 
last_chunk=last_chunk, - append=append, - ) - case TaskStatus( - state=( - TaskState.TASK_STATE_AUTH_REQUIRED | TaskState.TASK_STATE_INPUT_REQUIRED - ) as state, - message=message, - ): - await self.task_updater.update_status(state=state, message=self._with_context(message)) - self._working = False - resume_value = await self.resume_queue.get() - self.resume_queue.task_done() - case TaskStatus(state=state, message=message): - await self.task_updater.update_status(state=state, message=self._with_context(message)) - case TaskStatusUpdateEvent( - status=TaskStatus(state=state, message=message), - metadata=metadata, - ): - await self.task_updater.update_status( - state=state, message=self._with_context(message), metadata=dict(metadata) - ) - case TaskArtifactUpdateEvent( - artifact=Artifact(artifact_id=artifact_id, name=name, metadata=metadata, parts=parts), - append=append, - last_chunk=last_chunk, - ): - await self.task_updater.add_artifact( - parts=list(parts), - artifact_id=artifact_id, - name=name, - metadata=dict(metadata), - append=append, - last_chunk=last_chunk, - ) - case Metadata() as metadata: - await self.task_updater.update_status( - state=TaskState.TASK_STATE_WORKING, - message=self.task_updater.new_agent_message(parts=[], metadata=metadata), - ) - case dict() as data: - from google.protobuf.struct_pb2 import Struct, Value - - s = Struct() - s.update(data) - await self.task_updater.update_status( - state=TaskState.TASK_STATE_WORKING, - message=self.task_updater.new_agent_message( - parts=[Part(data=Value(struct_value=s))] - ), - ) - case Exception() as ex: - raise ex - case _: - raise ValueError(f"Invalid value yielded from agent: {type(yielded_value)}") - - await yield_resume_queue.async_q.put(resume_value) - - await self.task_updater.complete() - - except (janus.AsyncQueueShutDown, GeneratorExit): - await self.task_updater.complete() + with suppress(janus.AsyncQueueShutDown, GeneratorExit): + await self._agent_loop(task) + await task + final_message = 
self._accumulator.flush() + await self.task_updater.complete(message=self._prepare_message(final_message)) except Exception as ex: logger.error("Error when executing agent", exc_info=ex) await self.task_updater.failed(get_error_extension_context().server.message(ex)) - await cancel_task(task) except Exception as ex: logger.error("Error when executing agent", exc_info=ex) await self.task_updater.failed(get_error_extension_context().server.message(ex)) finally: self._working = False + if task: + with suppress(Exception): + await cancel_task(task) with suppress(Exception): self._handle_finish() diff --git a/apps/adk-py/src/kagenti_adk/server/constants.py b/apps/adk-py/src/kagenti_adk/server/constants.py index 0fec9ad9..b4fd32e5 100644 --- a/apps/adk-py/src/kagenti_adk/server/constants.py +++ b/apps/adk-py/src/kagenti_adk/server/constants.py @@ -8,9 +8,11 @@ from kagenti_adk.a2a.extensions import BaseExtensionServer from kagenti_adk.a2a.extensions.services.platform import PlatformApiExtensionServer, PlatformApiExtensionSpec +from kagenti_adk.a2a.extensions.streaming import StreamingExtensionServer, StreamingExtensionSpec from kagenti_adk.a2a.extensions.ui.error import ErrorExtensionParams, ErrorExtensionServer, ErrorExtensionSpec DEFAULT_IMPLICIT_EXTENSIONS: Final[dict[str, BaseExtensionServer]] = { + StreamingExtensionSpec.URI: StreamingExtensionServer(StreamingExtensionSpec()), ErrorExtensionSpec.URI: ErrorExtensionServer(ErrorExtensionSpec(ErrorExtensionParams())), PlatformApiExtensionSpec.URI: PlatformApiExtensionServer(PlatformApiExtensionSpec()), } diff --git a/apps/adk-py/src/kagenti_adk/server/context.py b/apps/adk-py/src/kagenti_adk/server/context.py index cb20cdb4..01e85eb2 100644 --- a/apps/adk-py/src/kagenti_adk/server/context.py +++ b/apps/adk-py/src/kagenti_adk/server/context.py @@ -8,7 +8,12 @@ from uuid import UUID import janus -from a2a.types import Artifact, Message, Task +from a2a.types import ( + Artifact, + Message, + Task, +) +from asgiref.sync 
import async_to_sync from pydantic import BaseModel, PrivateAttr from kagenti_adk.a2a.types import RunYield, RunYieldResume @@ -16,22 +21,26 @@ from kagenti_adk.server.store.context_store import ContextStoreInstance +class RunContextSettings(BaseModel): + strict: bool = False + + class RunContext(BaseModel, arbitrary_types_allowed=True): task_id: str context_id: str current_task: Task | None = None related_tasks: list[Task] | None = None + strict: bool = False # TODO: explain strict mode - what yields will stop message etc. Use in match/case _store: ContextStoreInstance _yield_queue: janus.Queue[RunYield] = PrivateAttr(default_factory=janus.Queue) + _yield_resume_queue: janus.Queue[RunYieldResume | Exception] = PrivateAttr(default_factory=janus.Queue) def __init__(self, _store: ContextStoreInstance, **data): super().__init__(**data) self._store = _store - _yield_resume_queue: janus.Queue[RunYieldResume] = PrivateAttr(default_factory=janus.Queue) - - async def store(self, data: Message | Artifact): + def _prepare_store_data(self, data: Message | Artifact) -> Message | Artifact: if not self._store: raise RuntimeError("Context store is not initialized") if isinstance(data, Message): @@ -39,8 +48,14 @@ async def store(self, data: Message | Artifact): msg.CopyFrom(data) msg.context_id = self.context_id msg.task_id = self.task_id - data = msg - await self._store.store(data) + return msg + return data + + async def store(self, data: Message | Artifact): + await self._store.store(self._prepare_store_data(data)) + + def store_sync(self, data: Message | Artifact): + async_to_sync(self._store.store)(self._prepare_store_data(data)) @overload def load_history(self, load_history_items: Literal[False] = False) -> AsyncGenerator[Message | Artifact, None]: ... 
@@ -63,11 +78,17 @@ async def delete_history_from_id(self, from_id: UUID) -> None: def yield_sync(self, value: RunYield) -> RunYieldResume: self._yield_queue.sync_q.put(value) - return self._yield_resume_queue.sync_q.get() + resp = self._yield_resume_queue.sync_q.get() + if isinstance(resp, Exception): + raise resp + return resp async def yield_async(self, value: RunYield) -> RunYieldResume: await self._yield_queue.async_q.put(value) - return await self._yield_resume_queue.async_q.get() + resp = await self._yield_resume_queue.async_q.get() + if isinstance(resp, Exception): + raise resp + return resp def shutdown(self) -> None: self._yield_queue.shutdown() diff --git a/apps/adk-py/src/kagenti_adk/server/dependencies.py b/apps/adk-py/src/kagenti_adk/server/dependencies.py index d02a82af..32ece69b 100644 --- a/apps/adk-py/src/kagenti_adk/server/dependencies.py +++ b/apps/adk-py/src/kagenti_adk/server/dependencies.py @@ -17,7 +17,7 @@ from typing_extensions import Doc from kagenti_adk.a2a.extensions.base import BaseExtensionServer, BaseExtensionSpec -from kagenti_adk.server.context import RunContext +from kagenti_adk.server.context import RunContext, RunContextSettings Dependency: TypeAlias = Callable[[Message, RunContext, RequestContext], Any] | BaseExtensionServer[Any, Any] @@ -58,9 +58,36 @@ async def lifespan() -> AsyncIterator[Dependency]: return lifespan() +def _get_param_type_hints(fn: Callable[..., Any]) -> dict[str, Any]: + """Get type hints for function parameters only, skipping the return annotation. + + typing.get_type_hints() evaluates all annotations including return type, + which can fail when annotations use `X | Y` with types that don't support + the `|` operator at runtime (e.g. protobuf classes, factory functions). 
+ """ + try: + return typing.get_type_hints(fn, include_extras=True) + except TypeError: + # Evaluate parameter annotations individually, skipping any that fail + globalns = getattr(fn, "__globals__", {}) + hints: dict[str, Any] = {} + for name, param in inspect.signature(fn).parameters.items(): + ann = param.annotation + if ann is inspect.Parameter.empty: + continue + if isinstance(ann, str): + try: + hints[name] = eval(ann, globalns) # noqa: S307 + except Exception: + hints[name] = ann + else: + hints[name] = ann + return hints + + def extract_dependencies(fn: Callable[..., Any]) -> dict[str, Depends]: sign = inspect.signature(fn) - type_hints = typing.get_type_hints(fn, include_extras=True) + type_hints = _get_param_type_hints(fn) dependencies = {} seen_keys = set() @@ -70,6 +97,11 @@ def process_args(name: str, args: tuple[Any, ...]) -> None: # extension_param: Annotated[some_type, Depends(some_callable)] if isinstance(spec, Depends): dependencies[name] = spec + # extension_param: Annotated[RunContext, RunContextSettings()] + if isinstance(dep_type, RunContext) and isinstance(spec, RunContextSettings): + dependencies[name] = Depends( + lambda _message, run_context, _request_context: run_context.model_copy(update=spec.model_dump()) + ) # extension_param: Annotated[BaseExtensionServer, BaseExtensionSpec()] elif ( isclass(dep_type) and issubclass(dep_type, BaseExtensionServer) and isinstance(spec, BaseExtensionSpec) @@ -114,7 +146,7 @@ def process_args(name: str, args: tuple[Any, ...]) -> None: if reserved_names := {param for param in dependencies if param.startswith("__")}: raise TypeError(f"User-defined dependencies cannot start with double underscore: {reserved_names}") - extension_deps = Counter(dep.extension.spec.URI for dep in dependencies.values() if dep.extension) + extension_deps = Counter(dep.extension.spec.URI for dep in dependencies.values() if dep.extension is not None) if duplicate_uris := {k for k, v in extension_deps.items() if v > 1}: raise 
TypeError(f"Duplicate extension URIs found in the agent function: {duplicate_uris}") diff --git a/apps/adk-py/src/kagenti_adk/server/exceptions.py b/apps/adk-py/src/kagenti_adk/server/exceptions.py index 7200f7b6..68b91907 100644 --- a/apps/adk-py/src/kagenti_adk/server/exceptions.py +++ b/apps/adk-py/src/kagenti_adk/server/exceptions.py @@ -1,5 +1,10 @@ # Copyright 2026 © IBM Corp. # SPDX-License-Identifier: Apache-2.0 - from __future__ import annotations +from kagenti_adk.a2a.types import RunYield + + +class InvalidYieldError(RuntimeError): + def __init__(self, yielded_value: RunYield): + super().__init__(f"Invalid yield of type: {type(yielded_value)}") diff --git a/apps/adk-py/src/kagenti_adk/server/jsonpatch_ext.py b/apps/adk-py/src/kagenti_adk/server/jsonpatch_ext.py new file mode 100644 index 00000000..7b2dca9d --- /dev/null +++ b/apps/adk-py/src/kagenti_adk/server/jsonpatch_ext.py @@ -0,0 +1,160 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import difflib +import json +from collections.abc import MutableMapping, MutableSequence +from typing import Any + +import jsonpatch +from jsonpatch import DiffBuilder, JsonPatch, PatchOperation + + +class StrInsOperation(PatchOperation): + """ + Inserts text into a string property at a specific position. + Operation format: { "op": "str_ins", "path": "/foo/bar", "pos": 5, "value": "text" } + If "pos" is omitted, it defaults to appending. 
+ """ + + def apply(self, obj: Any) -> Any: + try: + value = self.operation["value"] + except KeyError: + raise jsonpatch.InvalidJsonPatch("The operation does not contain a 'value' member") + + subobj, part = self.pointer.to_last(obj) + + if isinstance(subobj, MutableMapping): + if part not in subobj: + raise jsonpatch.JsonPatchConflict(f"Target path {self.location} does not exist") + current_val = subobj[part] + elif isinstance(subobj, MutableSequence): + try: + part_idx = int(part) # type: ignore [arg-type] + current_val = subobj[part_idx] + except (IndexError, ValueError): + raise jsonpatch.JsonPatchConflict(f"Target path {self.location} does not exist") + else: + raise jsonpatch.JsonPatchConflict(f"Cannot apply str_ins to {type(subobj)}") + + if not isinstance(current_val, str): + raise jsonpatch.JsonPatchConflict(f"Target value at {self.location} is not a string") + + pos = self.operation.get("pos", len(current_val)) + if not isinstance(pos, (int, float)): + raise jsonpatch.InvalidJsonPatch("The operation 'pos' member must be a number") + pos = int(pos) + + if pos < 0 or pos > len(current_val): + raise jsonpatch.JsonPatchConflict( + f"Position {pos} is out of bounds for string of length {len(current_val)}" + ) + + # Insert logic: string slicing + new_val = current_val[:pos] + value + current_val[pos:] + + if isinstance(subobj, MutableMapping): + subobj[part] = new_val + elif isinstance(subobj, MutableSequence): + subobj[int(part)] = new_val # type: ignore [arg-type] + + return obj + + def to_string(self) -> str: + # We need to include 'pos' in the string representation + op_dict = {"op": "str_ins", "path": self.location, "value": self.operation["value"]} + if "pos" in self.operation: + op_dict["pos"] = self.operation["pos"] + return json.dumps(op_dict) + + +class ExtendedDiffBuilder(DiffBuilder): + """ + Extended DiffBuilder that detects string insertions and generates `str_ins` operations. 
+ It uses difflib to find the most efficient string patch (currently focusing on single continuous insertion). + """ + + def _item_replaced(self, path: str, key: Any, item: Any) -> None: + """ + Called when a value is replaced. We check if it's a string modification + that can be represented as a str_ins operation. + """ + # Attempt to retrieve old value using the pointer + # path is the parent path, key is the item key/index + # We construct the full pointer to the item + full_path_str = path + "/" + str(key).replace("~", "~0").replace("/", "~1") + ptr = self.pointer_cls(full_path_str) + + try: + old_value = ptr.resolve(self.src_doc) + except Exception: + # Fallback to standard replace if we can't find old value (shouldn't happen) + super()._item_replaced(path, key, item) + return + + if isinstance(old_value, str) and isinstance(item, str): + # Optimization: Check for simple append first (O(1)ish vs O(N)) + if item.startswith(old_value) and len(item) > len(old_value): + diff = item[len(old_value) :] + self.insert( + StrInsOperation( + {"op": "str_ins", "path": full_path_str, "pos": len(old_value), "value": diff}, + pointer_cls=self.pointer_cls, + ) + ) + return + + # Analyze for arbitrary insertion using difflib + # We look for exactly ONE 'insert' block and everything else 'equal'. + # If there are deletes or replaces or multiple inserts, we fallback to full value replace + # because str_ins only does insertion, not deletion/replacement. + + matcher = difflib.SequenceMatcher(None, old_value, item) + opcodes = matcher.get_opcodes() + + insert_ops = [op for op in opcodes if op[0] == "insert"] + other_ops = [op for op in opcodes if op[0] != "insert"] + + # Condition: + # 1. Must have at least one insert (otherwise equal or delete) + # 2. All other ops must be 'equal' + # 3. Ideally only ONE insert op to keep patch simple (though we could support multiple) + # For streaming, we typically expect one chunk inserted. 
+ + if len(insert_ops) == 1 and all(op[0] == "equal" for op in other_ops): + tag, i1, i2, j1, j2 = insert_ops[0] + inserted_text = item[j1:j2] + + # 'pos' is i1 (index in old string where insertion starts) + + self.insert( + StrInsOperation( + {"op": "str_ins", "path": full_path_str, "pos": i1, "value": inserted_text}, + pointer_cls=self.pointer_cls, + ) + ) + return + + super()._item_replaced(path, key, item) + + +class ExtendedJsonPatch(JsonPatch): + operations = dict(JsonPatch.operations) + operations["str_ins"] = StrInsOperation # type: ignore [assignment] + + +def make_patch(src: Any, dst: Any) -> ExtendedJsonPatch: + """ + Generates a patch using the ExtendedDiffBuilder. + """ + builder = ExtendedDiffBuilder(src, dst, jsonpatch.json.dumps, jsonpatch.JsonPointer) + builder._compare_values("", None, src, dst) + ops = list(builder.execute()) + # Note: DiffBuilder.execute() yields operation dicts (each PatchOperation.operation), + # so `ops` is a plain list of dicts. ExtendedJsonPatch accepts a list of + # operation dicts directly — no re-parsing step is needed here — and + # dispatches the custom "str_ins" op via the extended operations table.
+ return ExtendedJsonPatch(ops) diff --git a/apps/adk-py/src/kagenti_adk/server/utils.py b/apps/adk-py/src/kagenti_adk/server/utils.py index 244f9f4c..08abc617 100644 --- a/apps/adk-py/src/kagenti_adk/server/utils.py +++ b/apps/adk-py/src/kagenti_adk/server/utils.py @@ -2,7 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations - +from kagenti_adk.types import JsonValue +from kagenti_adk.a2a.types import Metadata +from a2a.types import Message import asyncio from asyncio import CancelledError from contextlib import suppress @@ -26,3 +28,38 @@ async def close_queue(queue_manager: QueueManager, queue_name: str, immediate: b if queue := await queue_manager.get(queue_name): await queue.close(immediate=immediate) await queue_manager.close(queue_name) + + +def _merge_recursive(obj: JsonValue, other: JsonValue) -> JsonValue: + if isinstance(obj, dict) and isinstance(other, dict): + merged = {**obj} + for k, v in other.items(): + if k in merged: + merged[k] = _merge_recursive(merged[k], v) + else: + merged[k] = v + return merged + elif isinstance(obj, list) and isinstance(other, list): + return obj + other + else: + return other + + +def merge_metadata(*metadata_items: Metadata) -> Metadata: + result = Metadata() + for m in metadata_items: + for k, v in m.items(): + result[k] = _merge_recursive(result.get(k, {}), v) + return result + + +def merge_messages(*messages: Message) -> Message | None: + if not messages: + return None + merged = Message() + merged.CopyFrom(messages[0]) + for msg in messages[1:]: + merged.parts.extend(msg.parts) + for k, v in msg.metadata.items(): + merged.metadata[k] = v + return merged diff --git a/apps/adk-py/src/kagenti_adk/types.py b/apps/adk-py/src/kagenti_adk/types.py index 16777f21..49d54616 100644 --- a/apps/adk-py/src/kagenti_adk/types.py +++ b/apps/adk-py/src/kagenti_adk/types.py @@ -4,13 +4,15 @@ from __future__ import annotations import abc -from typing import TYPE_CHECKING, TypeAlias, TypedDict +from typing 
import TYPE_CHECKING, Required, TypeAlias, TypedDict from a2a.types import SecurityRequirement, SecurityScheme from starlette.authentication import AuthenticationBackend __all__ = [ "JsonDict", + "JsonPatch", + "JsonPatchOp", "JsonValue", "SdkAuthenticationBackend", ] @@ -19,13 +21,24 @@ JsonValue: TypeAlias = list["JsonValue"] | dict[str, "JsonValue"] | str | bool | int | float | None JsonDict: TypeAlias = dict[str, JsonValue] else: - from typing import Union + from typing import Union # noqa: F401 from typing_extensions import TypeAliasType - JsonValue = TypeAliasType("JsonValue", "Union[dict[str, JsonValue], list[JsonValue], str, int, float, bool, None]") # noqa: UP007 + JsonValue = TypeAliasType("JsonValue", "Union[dict[str, JsonValue], list[JsonValue], str, int, float, bool, None]") JsonDict = TypeAliasType("JsonDict", "dict[str, JsonValue]") +class JsonPatchOp(TypedDict, total=False): + """A single JSON Patch operation (RFC 6902), extended with 'str_ins' from json-crdt-patch.""" + + op: Required[str] + path: Required[str] + value: JsonValue + pos: int # str_ins extension: insertion position + + +JsonPatch: TypeAlias = list[JsonPatchOp] + class A2ASecurity(TypedDict): security_requirements: list[SecurityRequirement] diff --git a/apps/adk-py/tests/conftest.py b/apps/adk-py/tests/conftest.py deleted file mode 100644 index c2a8dadc..00000000 --- a/apps/adk-py/tests/conftest.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2026 © IBM Corp. 
-# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import async_lru - -# Disable async_lru event loop check in tests -async_lru._LRUCacheWrapper._check_loop = lambda self, loop: None diff --git a/apps/adk-py/tests/e2e/conftest.py b/apps/adk-py/tests/e2e/conftest.py index fe29e31f..9b0d998d 100644 --- a/apps/adk-py/tests/e2e/conftest.py +++ b/apps/adk-py/tests/e2e/conftest.py @@ -12,7 +12,8 @@ import httpx import pytest -from a2a.client import Client, ClientConfig, ClientFactory +from a2a.client import Client, ClientCallContext, ClientConfig, ClientFactory +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.types import ( AgentCard, Artifact, @@ -26,7 +27,7 @@ from tenacity import AsyncRetrying, stop_after_attempt, wait_fixed from kagenti_adk.a2a.extensions.ui.agent_detail import AgentDetail -from kagenti_adk.a2a.types import AgentArtifact, ArtifactChunk, InputRequired, RunYield, RunYieldResume +from kagenti_adk.a2a.types import AgentArtifact, AgentMessage, ArtifactChunk, InputRequired, RunYield, RunYieldResume from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext from kagenti_adk.server.store.context_store import ContextStore @@ -40,9 +41,19 @@ def get_free_port() -> int: return int(sock.getsockname()[1]) +def make_extension_context(extensions: list[str] | None = None) -> ClientCallContext | None: + """Create a ClientCallContext with extension URIs as service parameters.""" + if not extensions: + return None + return ClientCallContext(service_parameters={HTTP_EXTENSION_HEADER: ",".join(extensions)}) + + @asynccontextmanager async def run_server( - server: Server, port: int, context_store: ContextStore | None = None, task_timeout: timedelta | None = None + server: Server, + port: int, + context_store: ContextStore | None = None, + task_timeout: timedelta | None = None, ) -> AsyncGenerator[tuple[Server, Client]]: async with asyncio.TaskGroup() as tg: tg.create_task( @@ -66,7 +77,9 @@ async 
def run_server( card_resp = await httpx_client.get(f"{base_url}{AGENT_CARD_WELL_KNOWN_PATH}") card_resp.raise_for_status() card = ParseDict(card_resp.json(), AgentCard(), ignore_unknown_fields=True) - client = ClientFactory(ClientConfig(httpx_client=httpx_client)).create(card=card) + client = ClientFactory(ClientConfig(httpx_client=httpx_client)).create( + card=card, + ) yield server, client finally: server.should_exit = True @@ -78,7 +91,9 @@ def create_server_with_agent(): @asynccontextmanager async def _create_server( - agent_fn, context_store: ContextStore | None = None, task_timeout: timedelta | None = None + agent_fn, + context_store: ContextStore | None = None, + task_timeout: timedelta | None = None, ) -> AsyncIterator[tuple[Server, Client]]: server = Server() server.agent(detail=AgentDetail(interaction_mode="multi-turn"))(agent_fn) @@ -95,7 +110,7 @@ async def _create_server( @pytest.fixture async def echo(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: - async def echo(message: Message, context: RunContext) -> AsyncGenerator[str, Message]: + async def echo(message: Message, context: RunContext) -> AsyncGenerator[AgentMessage, Message]: for part in message.parts: yield part.text @@ -105,7 +120,7 @@ async def echo(message: Message, context: RunContext) -> AsyncGenerator[str, Mes @pytest.fixture async def slow_echo(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: - async def slow_echo(message: Message, context: RunContext) -> AsyncGenerator[str, Message]: + async def slow_echo(message: Message, context: RunContext) -> AsyncGenerator[AgentMessage, Message]: # Slower version with delay for part in message.parts: await asyncio.sleep(1) @@ -117,9 +132,9 @@ async def slow_echo(message: Message, context: RunContext) -> AsyncGenerator[str @pytest.fixture async def awaiter(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: - async def awaiter(message: Message, context: RunContext) -> AsyncGenerator[TaskStatus | 
str, Message]: + async def awaiter(message: Message, context: RunContext) -> AsyncGenerator[TaskStatus | AgentMessage, Message]: # Agent that requires input - yield "Processing initial message..." + yield AgentMessage(text="Processing initial message...") resume_message = yield InputRequired(text="need input") yield f"Received resume: {resume_message.parts[0].text if resume_message.parts else 'empty'}" @@ -130,9 +145,9 @@ async def awaiter(message: Message, context: RunContext) -> AsyncGenerator[TaskS @pytest.fixture async def awaiter_with_1s_timeout(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: - async def awaiter(message: Message, context: RunContext) -> AsyncGenerator[TaskStatus | str, Message]: + async def awaiter(message: Message, context: RunContext) -> AsyncGenerator[TaskStatus | AgentMessage, Message]: # Agent that requires input - yield "Processing initial message..." + yield AgentMessage(text="Processing initial message...") resume_message = yield InputRequired(text="need input") yield f"Received resume: {resume_message.parts[0].text if resume_message.parts else 'empty'}" @@ -153,7 +168,7 @@ async def failer(message: Message, context: RunContext) -> AsyncGenerator[RunYie @pytest.fixture async def raiser(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: - async def raiser(message: Message, context: RunContext) -> AsyncGenerator[str, Message]: + async def raiser(message: Message, context: RunContext) -> AsyncGenerator[AgentMessage, Message]: # Another failing agent raise RuntimeError("Wrong question buddy!") @@ -163,9 +178,11 @@ async def raiser(message: Message, context: RunContext) -> AsyncGenerator[str, M @pytest.fixture async def artifact_producer(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: - async def artifact_producer(message: Message, context: RunContext) -> AsyncGenerator[str | Artifact, Message]: + async def artifact_producer( + message: Message, context: RunContext + ) -> 
AsyncGenerator[AgentMessage | Artifact, Message]: # Agent producing artifacts - yield "Processing with artifacts" + yield AgentMessage(text="Processing with artifacts") # Create artifacts with proper parts structure yield AgentArtifact( @@ -195,23 +212,24 @@ async def chunked_artifact_producer( message: Message, context: RunContext ) -> AsyncGenerator[str | Artifact, Message]: # Agent producing chunked artifacts - yield "Processing chunked artifacts" + yield AgentMessage(text="Processing chunked artifacts") - # Create a large text artifact in chunks + # Create a large text artifact in chunks using ArtifactChunk with shared artifact_id + shared_id = "chunked-artifact-1" yield ArtifactChunk( - artifact_id="1", + artifact_id=shared_id, name="large-file.txt", parts=[Part(text="This is the first chunk of data.\n")], ) yield ArtifactChunk( - artifact_id="1", + artifact_id=shared_id, name="large-file.txt", parts=[Part(text="This is the second chunk of data.\n")], ) yield ArtifactChunk( - artifact_id="1", + artifact_id=shared_id, name="large-file.txt", parts=[Part(text="This is the final chunk of data.\n")], last_chunk=True, diff --git a/apps/adk-py/tests/e2e/test_streaming.py b/apps/adk-py/tests/e2e/test_streaming.py new file mode 100644 index 00000000..118efaeb --- /dev/null +++ b/apps/adk-py/tests/e2e/test_streaming.py @@ -0,0 +1,476 @@ +# Copyright 2026 © IBM Corp. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import AsyncGenerator, AsyncIterator + +import pytest +from a2a.client import Client +from a2a.client.helpers import create_text_message_object +from a2a.types import ( + Artifact, + Message, + Part, + Role, + SendMessageRequest, + StreamResponse, + TaskArtifactUpdateEvent, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, +) +from google.protobuf.json_format import MessageToDict + +from kagenti_adk.a2a.extensions.streaming import ( + ArtifactDelta, + MetadataDelta, + PartDelta, +
StateChange, + StreamingExtensionClient, + StreamingExtensionSpec, + TextDelta, +) +from kagenti_adk.a2a.types import AgentMessage, ArtifactChunk, InputRequired, Metadata, RunYield +from kagenti_adk.server import Server +from kagenti_adk.server.context import RunContext +from kagenti_adk.server.jsonpatch_ext import ExtendedJsonPatch +from conftest import make_extension_context + +pytestmark = pytest.mark.e2e + +STREAMING_URI = StreamingExtensionSpec.URI +STREAMING_CONTEXT = make_extension_context([STREAMING_URI]) + + +def extract_streaming_patches(events: list) -> list[dict]: + """Extract streaming patches from collected client events (flattened).""" + patches = [] + for event in events: + match event: + case (StreamResponse(status_update=TaskStatusUpdateEvent(metadata=metadata)), _) if metadata: + meta_dict = MessageToDict(metadata) + if STREAMING_URI in meta_dict: + patch_list = meta_dict[STREAMING_URI].get("message_update") + if isinstance(patch_list, list): + patches.extend(patch_list) + return patches + + +def apply_patches(patches: list[dict]) -> dict: + """Apply a sequence of streaming patches to build a message object.""" + return ExtendedJsonPatch(patches).apply({}) + + +def extract_status_events(events: list) -> list[TaskStatusUpdateEvent]: + """Extract all status update events.""" + status_events = [] + for event in events: + match event: + case (StreamResponse(status_update=TaskStatusUpdateEvent() as update), _): + if MessageToDict(update.status): + status_events.append(update) + return status_events + + +# --- Fixtures --- + + +@pytest.fixture +async def streaming_string_agent(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: + async def string_yielder(message: Message, context: RunContext) -> AsyncIterator[RunYield]: + yield "Hello" + yield " beautiful" + yield " world" + + async with create_server_with_agent( + string_yielder + ) as (server, client): + yield server, client + + +@pytest.fixture +async def 
streaming_part_agent(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: + async def part_yielder(message: Message, context: RunContext) -> AsyncIterator[RunYield]: + yield Part(text="explicit part") + + async with create_server_with_agent( + part_yielder + ) as (server, client): + yield server, client + + +@pytest.fixture +async def streaming_mixed_agent(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: + async def mixed_yielder(message: Message, context: RunContext) -> AsyncIterator[RunYield]: + yield "text1" + yield "text2" + yield AgentMessage(text="final") + + async with create_server_with_agent( + mixed_yielder + ) as (server, client): + yield server, client + + +@pytest.fixture +async def no_streaming_string_agent(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: + """Same agent as streaming_string_agent but WITHOUT streaming extension.""" + + async def string_yielder(message: Message, context: RunContext) -> AsyncIterator[RunYield]: + yield "Hello" + yield " beautiful" + yield " world" + + async with create_server_with_agent(string_yielder) as (server, client): + yield server, client + + +# --- Tests --- + + +async def test_string_yields_produce_streaming_patches(streaming_string_agent): + """Verify applying streaming patches builds the same message as the wire.""" + _, client = streaming_string_agent + events = [] + async for event in client.send_message(SendMessageRequest(message=create_text_message_object(content="hi")), context=STREAMING_CONTEXT): + events.append(event) + + patches = extract_streaming_patches(events) + assert len(patches) >= 2 + + # Apply all patches to build the message + built = apply_patches(patches) + assert built["parts"][0]["text"] == "Hello beautiful world" + + # Compare to the COMPLETED wire message + status_events = extract_status_events(events) + completed = [e for e in status_events if e.status.state == TaskState.TASK_STATE_COMPLETED] + wire_parts = [MessageToDict(p) for p 
in completed[0].status.message.parts] + assert wire_parts == built["parts"] + + +async def test_part_yield_produces_streaming_patch(streaming_part_agent): + """Verify applying streaming patches for Part yields builds the correct message.""" + _, client = streaming_part_agent + events = [] + async for event in client.send_message(SendMessageRequest(message=create_text_message_object(content="hi")), context=STREAMING_CONTEXT): + events.append(event) + + patches = extract_streaming_patches(events) + assert len(patches) >= 1 + + built = apply_patches(patches) + assert built["parts"][0]["text"] == "explicit part" + + status_events = extract_status_events(events) + completed = [e for e in status_events if e.status.state == TaskState.TASK_STATE_COMPLETED] + wire_parts = [MessageToDict(p) for p in completed[0].status.message.parts] + assert wire_parts == built["parts"] + + +async def test_completion_flushes_accumulated_message(streaming_string_agent): + """Verify that the completed event contains all accumulated text.""" + _, client = streaming_string_agent + events = [] + async for event in client.send_message(SendMessageRequest(message=create_text_message_object(content="hi")), context=STREAMING_CONTEXT): + events.append(event) + + status_events = extract_status_events(events) + completed = [e for e in status_events if e.status.state == TaskState.TASK_STATE_COMPLETED] + assert len(completed) == 1 + + final_message = completed[0].status.message + assert final_message.parts[0].text == "Hello beautiful world" + + +async def test_mixed_yields_message_flush(streaming_mixed_agent): + """Verify accumulated text is flushed before explicit AgentMessage.""" + _, client = streaming_mixed_agent + events = [] + async for event in client.send_message(SendMessageRequest(message=create_text_message_object(content="hi")), context=STREAMING_CONTEXT): + events.append(event) + + # Extract all messages from WORKING status events + working_messages = [] + for event in events: + match event: + 
case (StreamResponse(status_update=TaskStatusUpdateEvent(status=TaskStatus( + state=TaskState.TASK_STATE_WORKING, message=Message(message_id=mid) + ))), _) if mid: + working_messages.append(event[0].status_update.status.message) + + # Should have at least the explicit AgentMessage as a WORKING event + # The accumulated "text1text2" is flushed as a draft merged into the AgentMessage + assert len(working_messages) >= 1 + + # Verify the completed event contains both messages in history + status_events = extract_status_events(events) + completed = [e for e in status_events if e.status.state == TaskState.TASK_STATE_COMPLETED] + assert len(completed) == 1 + + +async def test_no_streaming_patches_without_extension(no_streaming_string_agent): + """Verify no streaming metadata when extension is not activated by client.""" + _, client = no_streaming_string_agent + events = [] + async for event in client.send_message(SendMessageRequest(message=create_text_message_object(content="hi"))): + events.append(event) + + patches = extract_streaming_patches(events) + assert len(patches) == 0 + + # But the final message should still have the accumulated text + status_events = extract_status_events(events) + completed = [e for e in status_events if e.status.state == TaskState.TASK_STATE_COMPLETED] + assert len(completed) == 1 + assert completed[0].status.message.parts[0].text == "Hello beautiful world" + + +async def test_complex_accumulator_state_machine(create_server_with_agent): + """Comprehensive test exercising all accumulator state transitions in a single scenario. 
+ + Agent yield sequence and expected state machine transitions: + + yield "Hello" Base → TextPart (add /parts/-) + yield " world" TextPart → TextPart (str_ins) + yield Part("[separator]") TextPart → Message (flush text part, add explicit part) + yield {"score": 42} Message → Message (dict→Part, add /parts/-) + yield Metadata(ext: [a]) Message → Message (add /metadata) + yield Metadata(ext: [b]) Message → Message (replace /metadata, arrays concatenated) + yield AgentMessage("ckpt") Message → Base (flush draft, dispatch WORKING message) + yield ArtifactChunk(...) handled outside accumulator (artifact event) + yield "post-artifact" Base → TextPart (new accumulation after reset) + yield InputRequired(...) TextPart → Base (flush draft, INPUT_REQUIRED) + --- resume --- + yield "you said: ..." Base → TextPart (new accumulation) + implicit flush → COMPLETED + """ + + async def complex_agent(message: Message, context: RunContext) -> AsyncIterator[RunYield]: + # Phase 1: Text streaming (Base → TextPart → TextPart) + yield "Hello" + yield " world" + + # Phase 2: Part after text (TextPart → Message; flushes text, adds explicit part) + yield Part(text="[separator]") + + # Phase 3: Dict yield (Message → Message; dict converted to data Part) + yield {"score": 42} + + # Phase 4: Metadata accumulation with array concatenation + yield Metadata({"ext://test": [{"ref": "a"}]}) + yield Metadata({"ext://test": [{"ref": "b"}]}) + + # Phase 5: Explicit Message flushes everything accumulated so far + # Draft: parts=[Text("Hello world"), Part("[separator]"), DataPart({score:42})], + # metadata={ext://test: [{ref:a},{ref:b}]} + # Merged with AgentMessage: draft parts + [Part("checkpoint")] + yield AgentMessage(text="checkpoint") + + # Phase 6: Artifact (bypasses accumulator entirely) + yield ArtifactChunk( + artifact_id="art-1", name="data.txt", + parts=[Part(text="artifact body")], last_chunk=True, + ) + + # Phase 7: New accumulation cycle after full reset + yield "post-artifact" + + # 
Phase 8: InputRequired flushes accumulated text, pauses for user input + # Draft: parts=[Text("post-artifact")] + # Merged with InputRequired message: [Text("post-artifact"), Text("what next?")] + resume_msg = yield InputRequired(text="what next?") + + # Phase 9: After resume — new accumulation, then implicit flush on return + yield f"you said: {resume_msg.parts[0].text}" + + async with create_server_with_agent(complex_agent) as (_, client): + # --- First send: initial message --- + events_1 = [] + async for event in client.send_message(SendMessageRequest(message=create_text_message_object(content="go")), context=STREAMING_CONTEXT): + events_1.append(event) + + all_patches = extract_streaming_patches(events_1) + status_events = extract_status_events(events_1) + + # Split patches into accumulation cycles by root replace boundaries + cycles: list[list[dict]] = [] + for patch in all_patches: + if patch["op"] == "replace" and patch["path"] == "": + cycles.append([]) + cycles[-1].append(patch) + assert len(cycles) == 2 # cycle 1: before AgentMessage, cycle 2: after artifact + + # -- Cycle 1: Apply patches to build the draft -- + # Yields: "Hello", " world", Part("[separator]"), {"score": 42}, Metadata x2 + draft_1 = apply_patches(cycles[0]) + assert draft_1["parts"][0]["text"] == "Hello world" + assert draft_1["parts"][1]["text"] == "[separator]" + assert draft_1["parts"][2]["data"] == {"score": 42.0} + assert draft_1["metadata"]["ext://test"] == [{"ref": "a"}, {"ref": "b"}] + + # The WORKING wire message = merge(draft_1, AgentMessage("checkpoint")) + # Draft parts are a prefix of the wire message parts + working = [ + e for e in status_events + if e.status.state == TaskState.TASK_STATE_WORKING and e.status.message.message_id + ] + assert len(working) == 1 + wire_msg = working[0].status.message + wire_parts = [MessageToDict(p) for p in wire_msg.parts] + wire_meta = MessageToDict(wire_msg.metadata) + + # Draft's 3 parts are the prefix; AgentMessage adds "checkpoint" as 
the 4th + assert wire_parts[:3] == draft_1["parts"] + assert wire_parts[3]["text"] == "checkpoint" + assert wire_meta == draft_1["metadata"] + + # -- Artifact event (bypasses accumulator) -- + artifact_events = [ + event for event in events_1 + if isinstance(event[0], StreamResponse) and event[0].artifact_update.artifact.artifact_id + ] + assert len(artifact_events) == 1 + assert artifact_events[0][0].artifact_update.artifact.name == "data.txt" + + # -- Cycle 2: Apply patches to build the draft -- + # Yields: "post-artifact" (then InputRequired flushes) + draft_2 = apply_patches(cycles[1]) + assert draft_2["parts"][0]["text"] == "post-artifact" + + # The INPUT_REQUIRED wire message = merge(draft_2, InputRequired("what next?")) + input_required = [ + e for e in status_events + if e.status.state == TaskState.TASK_STATE_INPUT_REQUIRED + ] + assert len(input_required) == 1 + ir_parts = [MessageToDict(p) for p in input_required[0].status.message.parts] + assert ir_parts[:1] == draft_2["parts"] + assert ir_parts[1]["text"] == "what next?" 
+ + # --- Second send: resume --- + task_id = events_1[-1][1].id + resume = create_text_message_object(content="hello again") + resume.task_id = task_id + + events_2 = [] + async for event in client.send_message(SendMessageRequest(message=resume), context=STREAMING_CONTEXT): + events_2.append(event) + + status_events_2 = extract_status_events(events_2) + patches_2 = extract_streaming_patches(events_2) + + # Cycle 3: Apply patches — should match the COMPLETED wire message exactly + draft_3 = apply_patches(patches_2) + completed = [e for e in status_events_2 if e.status.state == TaskState.TASK_STATE_COMPLETED] + assert len(completed) == 1 + wire_completed_parts = [MessageToDict(p) for p in completed[0].status.message.parts] + assert wire_completed_parts == draft_3["parts"] + + +# --- StreamingExtensionClient tests --- + + +def _make_streaming_client() -> StreamingExtensionClient: + return StreamingExtensionClient(StreamingExtensionSpec()) + + +async def test_streaming_client_text_deltas(streaming_string_agent): + """Verify StreamingExtensionClient emits TextDelta for streamed text chunks.""" + _, client = streaming_string_agent + streaming = _make_streaming_client() + + text_deltas = [] + state_changes = [] + async for delta, task in streaming.stream(client.send_message(SendMessageRequest(message=create_text_message_object(content="hi")), context=STREAMING_CONTEXT)): + match delta: + case TextDelta() as td: + text_deltas.append(td.delta) + case PartDelta(): + pass + case StateChange() as sc: + state_changes.append(sc) + + # Should have text deltas for " beautiful" and " world" (first chunk is PartDelta from root replace) + assert len(text_deltas) >= 2 + + # Verify final state change is COMPLETED + completed = [sc for sc in state_changes if sc.state == TaskState.TASK_STATE_COMPLETED] + assert len(completed) == 1 + # The completed message should have been reconciled (already streamed) + assert completed[0].message is not None + + +async def 
test_streaming_client_part_delta(streaming_part_agent): + """Verify StreamingExtensionClient emits PartDelta for explicit Part yields.""" + _, client = streaming_part_agent + streaming = _make_streaming_client() + + part_deltas = [] + async for delta, task in streaming.stream(client.send_message(SendMessageRequest(message=create_text_message_object(content="hi")), context=STREAMING_CONTEXT)): + match delta: + case PartDelta() as pd: + part_deltas.append(pd) + case _: + pass + + assert len(part_deltas) >= 1 + assert part_deltas[0].part["text"] == "explicit part" + + +async def test_streaming_client_without_extension(no_streaming_string_agent): + """Verify StreamingExtensionClient works without streaming extension (decompose full messages).""" + _, client = no_streaming_string_agent + streaming = _make_streaming_client() + + part_deltas = [] + state_changes = [] + async for delta, task in streaming.stream(client.send_message(SendMessageRequest(message=create_text_message_object(content="hi")))): + match delta: + case PartDelta() as pd: + part_deltas.append(pd) + case StateChange() as sc: + state_changes.append(sc) + case _: + pass + + # Without streaming, the completed message should be decomposed into PartDelta + completed = [sc for sc in state_changes if sc.state == TaskState.TASK_STATE_COMPLETED] + assert len(completed) == 1 + # The message parts should appear as PartDelta events + assert any(pd.part.get("text") == "Hello beautiful world" for pd in part_deltas) + + +async def test_streaming_client_reconciles_streamed_messages(streaming_mixed_agent): + """Verify that messages already streamed via patches are properly reconciled.""" + _, client = streaming_mixed_agent + streaming = _make_streaming_client() + + all_deltas = [] + async for delta, task in streaming.stream(client.send_message(SendMessageRequest(message=create_text_message_object(content="hi")), context=STREAMING_CONTEXT)): + all_deltas.append(delta) + + # Should have deltas from streaming (text1, text2) 
and then the explicit AgentMessage + text_deltas = [d for d in all_deltas if isinstance(d, TextDelta)] + part_deltas = [d for d in all_deltas if isinstance(d, PartDelta)] + state_changes = [d for d in all_deltas if isinstance(d, StateChange)] + + # The streaming patches produce text/part deltas for "text1" and "text2" + assert len(text_deltas) + len(part_deltas) >= 1 + + # Should end with COMPLETED state + assert any(sc.state == TaskState.TASK_STATE_COMPLETED for sc in state_changes) + + +async def test_streaming_client_message_id_tracking(streaming_string_agent): + """Verify message_id is tracked from streaming patches.""" + _, client = streaming_string_agent + streaming = _make_streaming_client() + + async for delta, task in streaming.stream(client.send_message(SendMessageRequest(message=create_text_message_object(content="hi")), context=STREAMING_CONTEXT)): + pass + + # After stream completes, the draft should have been used + # message_id should have been set during streaming + # (it gets cleared after reconciliation, so we check the completed state) diff --git a/apps/adk-py/tests/e2e/test_yields.py b/apps/adk-py/tests/e2e/test_yields.py index 7b1f347c..143797ae 100644 --- a/apps/adk-py/tests/e2e/test_yields.py +++ b/apps/adk-py/tests/e2e/test_yields.py @@ -236,11 +236,13 @@ async def test_sync_function_with_context_agent(sync_function_with_context_agent assert final_task is not None assert final_task.status.state == TaskState.TASK_STATE_COMPLETED - # Should have intermediate yield and final result + # Consecutive string yields are accumulated into a single message # pyrefly: ignore [missing-attribute, not-iterable] - messages = [msg.parts[0].text for msg in final_task.history if msg.role == Role.ROLE_AGENT] - assert "first sync yield" in messages - assert "sync_function_with_context_agent: hello" in messages + agent_messages = [msg for msg in final_task.history if msg.role == Role.ROLE_AGENT] + assert len(agent_messages) == 1 + text = 
agent_messages[0].parts[0].text + assert "first sync yield" in text + assert "sync_function_with_context_agent: hello" in text async def test_sync_generator_agent(sync_generator_agent): @@ -252,10 +254,13 @@ async def test_sync_generator_agent(sync_generator_agent): assert final_task is not None assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + # Consecutive string yields are accumulated into a single message # pyrefly: ignore [missing-attribute, not-iterable] - messages = [msg.parts[0].text for msg in final_task.history if msg.role == Role.ROLE_AGENT] - assert "sync_generator yield 1" in messages - assert "sync_generator yield 2" in messages + agent_messages = [msg for msg in final_task.history if msg.role == Role.ROLE_AGENT] + assert len(agent_messages) == 1 + text = agent_messages[0].parts[0].text + assert "sync_generator yield 1" in text + assert "sync_generator yield 2" in text async def test_sync_generator_with_context_agent(sync_generator_with_context_agent): @@ -267,12 +272,15 @@ async def test_sync_generator_with_context_agent(sync_generator_with_context_age assert final_task is not None assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + # Consecutive string yields are accumulated into a single message # pyrefly: ignore [missing-attribute, not-iterable] - messages = [msg.parts[0].text for msg in final_task.history if msg.role == Role.ROLE_AGENT] - assert "sync_generator_with_context yield 1" in messages - assert "sync_generator_with_context context yield" in messages - assert "sync_generator_with_context yield 2" in messages - assert "sync_generator_with_context_agent: hello" in messages + agent_messages = [msg for msg in final_task.history if msg.role == Role.ROLE_AGENT] + assert len(agent_messages) == 1 + text = agent_messages[0].parts[0].text + assert "sync_generator_with_context yield 1" in text + assert "sync_generator_with_context context yield" in text + assert "sync_generator_with_context yield 2" in text + assert 
"sync_generator_with_context_agent: hello" in text async def test_async_function_agent(async_function_agent): @@ -297,10 +305,13 @@ async def test_async_function_with_context_agent(async_function_with_context_age assert final_task is not None assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + # Consecutive string yields are accumulated into a single message # pyrefly: ignore [missing-attribute, not-iterable] - messages = [msg.parts[0].text for msg in final_task.history if msg.role == Role.ROLE_AGENT] - assert "first async yield" in messages - assert "async_function_with_context_agent: hello" in messages + agent_messages = [msg for msg in final_task.history if msg.role == Role.ROLE_AGENT] + assert len(agent_messages) == 1 + text = agent_messages[0].parts[0].text + assert "first async yield" in text + assert "async_function_with_context_agent: hello" in text async def test_async_generator_agent(async_generator_agent): @@ -312,11 +323,14 @@ async def test_async_generator_agent(async_generator_agent): assert final_task is not None assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + # Consecutive string yields are accumulated into a single message # pyrefly: ignore [missing-attribute, not-iterable] - messages = [msg.parts[0].text for msg in final_task.history if msg.role == Role.ROLE_AGENT] - assert "async_generator yield 1" in messages - assert "async_generator yield 2" in messages - assert "async_generator_agent: hello" in messages + agent_messages = [msg for msg in final_task.history if msg.role == Role.ROLE_AGENT] + assert len(agent_messages) == 1 + text = agent_messages[0].parts[0].text + assert "async_generator yield 1" in text + assert "async_generator yield 2" in text + assert "async_generator_agent: hello" in text async def test_async_generator_with_context_agent(async_generator_with_context_agent): @@ -328,12 +342,15 @@ async def test_async_generator_with_context_agent(async_generator_with_context_a assert final_task is not None 
assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + # Consecutive string yields are accumulated into a single message # pyrefly: ignore [missing-attribute, not-iterable] - messages = [msg.parts[0].text for msg in final_task.history if msg.role == Role.ROLE_AGENT] - assert "async_generator_with_context yield 1" in messages - assert "async_generator_with_context context yield" in messages - assert "async_generator_with_context yield 2" in messages - assert "async_generator_with_context_agent: hello" in messages + agent_messages = [msg for msg in final_task.history if msg.role == Role.ROLE_AGENT] + assert len(agent_messages) == 1 + text = agent_messages[0].parts[0].text + assert "async_generator_with_context yield 1" in text + assert "async_generator_with_context context yield" in text + assert "async_generator_with_context yield 2" in text + assert "async_generator_with_context_agent: hello" in text async def test_sync_function_resume_agent(sync_function_resume_agent): @@ -369,9 +386,10 @@ async def test_sync_generator_resume_agent(sync_generator_resume_agent): assert initial_task is not None assert initial_task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED + # The "starting" string is flushed as a message before the InputRequired status # pyrefly: ignore [missing-attribute, not-iterable] - messages = [msg.parts[0].text for msg in initial_task.history if msg.role == Role.ROLE_AGENT] - assert "sync_generator_resume_agent: starting" in messages + agent_messages = [msg for msg in initial_task.history if msg.role == Role.ROLE_AGENT] + assert any("sync_generator_resume_agent: starting" in msg.parts[0].text for msg in agent_messages) # Resume with additional data resume_message = create_text_message_object(content="resume data") @@ -383,8 +401,8 @@ async def test_sync_generator_resume_agent(sync_generator_resume_agent): assert final_task is not None assert final_task.status.state == TaskState.TASK_STATE_COMPLETED # pyrefly: ignore [missing-attribute, 
not-iterable] - messages = [msg.parts[0].text for msg in final_task.history if msg.role == Role.ROLE_AGENT] - assert "sync_generator_resume_agent: received resume data" in messages + agent_messages = [msg for msg in final_task.history if msg.role == Role.ROLE_AGENT] + assert any("sync_generator_resume_agent: received resume data" in msg.parts[0].text for msg in agent_messages) async def test_async_function_resume_agent(async_function_resume_agent): @@ -421,9 +439,10 @@ async def test_async_generator_resume_agent(async_generator_resume_agent): assert initial_task is not None assert initial_task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED + # The "starting" string is flushed as a message before the InputRequired status # pyrefly: ignore [missing-attribute, not-iterable] - messages = [msg.parts[0].text for msg in initial_task.history if msg.role == Role.ROLE_AGENT] - assert "async_generator_resume_agent: starting" in messages + agent_messages = [msg for msg in initial_task.history if msg.role == Role.ROLE_AGENT] + assert any("async_generator_resume_agent: starting" in msg.parts[0].text for msg in agent_messages) # Resume with additional data resume_message = create_text_message_object(content="resume data") @@ -435,8 +454,8 @@ async def test_async_generator_resume_agent(async_generator_resume_agent): assert final_task is not None assert final_task.status.state == TaskState.TASK_STATE_COMPLETED # pyrefly: ignore [missing-attribute, not-iterable] - messages = [msg.parts[0].text for msg in final_task.history if msg.role == Role.ROLE_AGENT] - assert "async_generator_resume_agent: received resume data" in messages + agent_messages = [msg for msg in final_task.history if msg.role == Role.ROLE_AGENT] + assert any("async_generator_resume_agent: received resume data" in msg.parts[0].text for msg in agent_messages) async def test_sync_function_streaming(sync_function_agent): @@ -457,7 +476,7 @@ async def test_sync_function_streaming(sync_function_agent): async def 
test_sync_generator_streaming(sync_generator_agent): - """Test synchronous generator agent with streaming to see intermediate yields.""" + """Test synchronous generator agent with streaming events.""" _, client = sync_generator_agent events = [] async for event in client.send_message(SendMessageRequest(message=create_text_message_object(content="hello"))): @@ -472,13 +491,14 @@ async def test_sync_generator_streaming(sync_generator_agent): assert len(status_events) > 0 assert status_events[-1].status.state == TaskState.TASK_STATE_COMPLETED - # Should see multiple working state messages for each yield - working_events = [e for e in status_events if e.status.state == TaskState.TASK_STATE_WORKING] - assert len(working_events) >= 3 # At least 3 yields from the generator + # Without the streaming extension activated by the client, partial updates are not sent. + # The final completed message contains the accumulated result. + completed = status_events[-1] + assert completed.status.message.parts[0].text async def test_async_generator_streaming(async_generator_agent): - """Test asynchronous generator agent with streaming to see intermediate yields.""" + """Test asynchronous generator agent with streaming events.""" _, client = async_generator_agent events = [] async for event in client.send_message(SendMessageRequest(message=create_text_message_object(content="hello"))): @@ -492,13 +512,15 @@ async def test_async_generator_streaming(async_generator_agent): assert len(status_events) > 0 assert status_events[-1].status.state == TaskState.TASK_STATE_COMPLETED - # Should see multiple working state messages for each yield - working_events = [e for e in status_events if e.status.state == TaskState.TASK_STATE_WORKING] - assert len(working_events) >= 2 # At least 2 yields from the generator + # Without the streaming extension activated by the client, partial updates are not sent. + # The final completed message contains the accumulated result. 
+ completed = status_events[-1] + assert completed.status.message.parts[0].text async def test_yield_dict_vs_metadata(create_server_with_agent): async def yielder_of_meta_data() -> AsyncIterator[RunYield]: + # dict → Part(data=...), Metadata, and AgentMessage are accumulated into one message yield {"data": "this should be datapart"} yield Metadata({"metadata": "this should be metadata"}) yield AgentMessage( @@ -513,26 +535,24 @@ async def yielder_of_meta_data() -> AsyncIterator[RunYield]: assert final_task is not None assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + # dict+Metadata are accumulated, then flushed as draft when AgentMessage arrives + # The final message merges the draft (data part + metadata) with the AgentMessage (metadata) + # pyrefly: ignore [missing-attribute, not-iterable] + agent_messages = [msg for msg in final_task.history if msg.role == Role.ROLE_AGENT] + assert len(agent_messages) == 1 + merged = agent_messages[0] # pyrefly: ignore [missing-attribute, unsupported-operation] - assert MessageToDict(final_task.history[0].parts[0].data) == {"data": "this should be datapart"} + assert MessageToDict(merged.parts[0].data) == {"data": "this should be datapart"} + # Metadata from both Metadata yield and AgentMessage are merged # pyrefly: ignore [unsupported-operation] - assert MessageToDict(final_task.history[1].metadata) == {"metadata": "this should be metadata"} - # pyrefly: ignore [unsupported-operation] - assert MessageToDict(final_task.history[2].metadata) == { - "metadata": "this class still behaves as dict", - "metadata2": "and can be used in union", - } - # pyrefly: ignore [unsupported-operation] - assert not final_task.history[0].metadata - # pyrefly: ignore [unsupported-operation] - assert not final_task.history[1].parts - # pyrefly: ignore [unsupported-operation] - assert not final_task.history[2].parts + merged_meta = MessageToDict(merged.metadata) + assert merged_meta["metadata"] == "this class still behaves as dict" + 
assert merged_meta["metadata2"] == "and can be used in union" async def test_yield_of_all_types(create_server_with_agent): async def yielder_of_all_types_agent(message: Message, context: RunContext) -> AsyncIterator[RunYield]: - """Synchronous function agent that returns a string directly.""" + """Agent that yields all supported types.""" text_part = Part(text="text") message = AgentMessage(parts=[text_part], role=Role.ROLE_AGENT, message_id=str(uuid.uuid4())) yield message @@ -571,5 +591,10 @@ async def yielder_of_all_types_agent(message: Message, context: RunContext) -> A _, ) if artifact_id: artifact_cnt += 1 - assert message_cnt == 9 + # With streaming accumulation: + # - Message yield → 1 message + # - text_part accumulated, then TaskStatus flushes → 1 merged message + # - Parts accumulated, then TaskStatusUpdateEvent flushes → 1 merged message + # - str/dict/Metadata accumulated, then completion flushes → 1 message + assert message_cnt == 4 assert artifact_cnt == 2 diff --git a/apps/adk-py/tests/test_merge_utils.py b/apps/adk-py/tests/test_merge_utils.py new file mode 100644 index 00000000..015943a4 --- /dev/null +++ b/apps/adk-py/tests/test_merge_utils.py @@ -0,0 +1,60 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC + +from kagenti_adk.server.utils import _merge_recursive, merge_metadata +from kagenti_adk.a2a.types import Metadata +import pytest + + +def test_scalar_merge(): + assert _merge_recursive(1, 2) == 2 + assert _merge_recursive("a", "b") == "b" + assert _merge_recursive(True, False) == False + assert _merge_recursive(None, 1) == 1 + + +def test_list_merge(): + assert _merge_recursive([1, 2], [3, 4]) == [1, 2, 3, 4] + assert _merge_recursive([], [1]) == [1] + assert _merge_recursive([1], []) == [1] + + +def test_dict_merge_simple(): + a = {"x": 1} + b = {"y": 2} + expected = {"x": 1, "y": 2} + assert _merge_recursive(a, b) == expected + + +def test_dict_merge_overwrite(): + a = {"x": 1} + b = {"x": 2} + expected = {"x": 2} + assert 
_merge_recursive(a, b) == expected + + +def test_dict_merge_recursive(): + a = {"x": 1, "y": {"a": 1}} + b = {"y": {"b": 2}, "z": 3} + expected = {"x": 1, "y": {"a": 1, "b": 2}, "z": 3} + assert _merge_recursive(a, b) == expected + + +def test_dict_merge_nested_list(): + a = {"x": [1]} + b = {"x": [2]} + expected = {"x": [1, 2]} + assert _merge_recursive(a, b) == expected + + +def test_merge_metadata(): + m1 = Metadata(foo="bar") + m2 = Metadata(baz="qux") + merged = merge_metadata(m1, m2) + assert merged == Metadata(foo="bar", baz="qux") + + +def test_merge_metadata_recursive(): + m1 = Metadata(nested={"a": 1}) + m2 = Metadata(nested={"b": 2}) + merged = merge_metadata(m1, m2) + assert merged == Metadata(nested={"a": 1, "b": 2}) diff --git a/apps/adk-py/tests/unit/server/test_accumulator.py b/apps/adk-py/tests/unit/server/test_accumulator.py new file mode 100644 index 00000000..0cd9c979 --- /dev/null +++ b/apps/adk-py/tests/unit/server/test_accumulator.py @@ -0,0 +1,292 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest +from a2a.types import Part, Role, TaskState, TaskStatus, TaskStatusUpdateEvent + +from kagenti_adk.a2a.types import AgentMessage, Metadata +from kagenti_adk.server.accumulator import MessageAccumulator, MessageContext, TextPartContext + +pytestmark = pytest.mark.unit + + +# --- State transition tests --- + + +class TestStateTransitions: + def test_initial_state_is_base_level(self): + acc = MessageAccumulator() + assert isinstance(acc.active_context, MessageAccumulator) + + def test_string_enters_text_part_context(self): + acc = MessageAccumulator() + result = acc.process("hello") + assert result.accumulated is True + assert isinstance(acc.active_context, TextPartContext) + + def test_consecutive_strings_stay_in_text_part(self): + acc = MessageAccumulator() + acc.process("a") + result = acc.process("b") + assert result.accumulated is True + assert 
isinstance(acc.active_context, TextPartContext) + + def test_part_enters_message_context(self): + acc = MessageAccumulator() + result = acc.process(Part(text="x")) + assert result.accumulated is True + assert isinstance(acc.active_context, MessageContext) + + def test_metadata_enters_message_context(self): + acc = MessageAccumulator() + result = acc.process(Metadata({"key": "val"})) + assert result.accumulated is True + assert isinstance(acc.active_context, MessageContext) + + def test_part_after_string_transitions_to_message_context(self): + acc = MessageAccumulator() + acc.process("text chunk") + result = acc.process(Part(text="explicit part")) + assert result.accumulated is True + assert isinstance(acc.active_context, MessageContext) + # The text chunk was built into a Part and added to the MessageContext + ctx = acc.active_context + assert len(ctx.parts) == 2 + assert ctx.parts[0].text == "text chunk" + assert ctx.parts[1].text == "explicit part" + + def test_message_passthrough_at_base_level(self): + acc = MessageAccumulator() + msg = AgentMessage(text="hello") + result = acc.process(msg) + assert result.accumulated is False + assert result.draft is None + assert isinstance(acc.active_context, MessageAccumulator) + + def test_task_status_passthrough_at_base_level(self): + acc = MessageAccumulator() + status = TaskStatus(state=TaskState.TASK_STATE_WORKING) + result = acc.process(status) + assert result.accumulated is False + assert result.draft is None + assert isinstance(acc.active_context, MessageAccumulator) + + def test_task_status_update_event_passthrough_at_base_level(self): + acc = MessageAccumulator() + event = TaskStatusUpdateEvent( + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + task_id="t1", + context_id="c1", + ) + result = acc.process(event) + assert result.accumulated is False + assert isinstance(acc.active_context, MessageAccumulator) + + def test_message_flushes_accumulated_text(self): + acc = MessageAccumulator() + acc.process("a") + 
acc.process("b") + msg = AgentMessage(text="final") + result = acc.process(msg) + assert result.accumulated is False + assert result.draft is not None + assert result.draft.parts[0].text == "ab" + assert isinstance(acc.active_context, MessageAccumulator) + + def test_task_status_flushes_accumulated_parts(self): + acc = MessageAccumulator() + acc.process(Part(text="hello")) + status = TaskStatus(state=TaskState.TASK_STATE_WORKING) + result = acc.process(status) + assert result.accumulated is False + assert result.draft is not None + assert result.draft.parts[0].text == "hello" + assert isinstance(acc.active_context, MessageAccumulator) + + def test_message_flushes_accumulated_text_and_parts(self): + acc = MessageAccumulator() + acc.process("text chunk") + acc.process(Part(text="explicit part")) + acc.process(Metadata({"key": "val"})) + msg = AgentMessage(text="final") + result = acc.process(msg) + assert result.accumulated is False + assert result.draft is not None + # Draft should contain the accumulated text part and the explicit part + assert len(result.draft.parts) == 2 + assert result.draft.parts[0].text == "text chunk" + assert result.draft.parts[1].text == "explicit part" + assert isinstance(acc.active_context, MessageAccumulator) + + def test_string_after_part_enters_text_part_context(self): + acc = MessageAccumulator() + acc.process(Part(text="first")) + assert isinstance(acc.active_context, MessageContext) + acc.process("streaming text") + assert isinstance(acc.active_context, TextPartContext) + + def test_input_required_flushes_text(self): + acc = MessageAccumulator() + acc.process("thinking...") + status = TaskStatus(state=TaskState.TASK_STATE_INPUT_REQUIRED) + result = acc.process(status) + assert result.accumulated is False + assert result.draft is not None + assert result.draft.parts[0].text == "thinking..." 
+ + +# --- Patch verification tests --- + + +class TestPatchOutput: + def test_first_string_produces_replace_root_patch(self): + acc = MessageAccumulator() + result = acc.process("Hello") + assert result.patch is not None + assert len(result.patch) == 1 + patch = result.patch[0] + assert patch["op"] == "replace" + assert patch["path"] == "" + value = patch["value"] # type: ignore[typeddict-item] + assert value["parts"] == [{"text": "Hello"}] # type: ignore[index] + assert "message_id" in value # type: ignore[operator] + + def test_subsequent_string_produces_str_ins_patch(self): + acc = MessageAccumulator() + acc.process("Hello") + result = acc.process(" world") + assert result.patch is not None + assert len(result.patch) == 1 + patch = result.patch[0] + assert patch["op"] == "str_ins" + assert patch["path"] == "/parts/0/text" + assert patch["pos"] == 5 + assert patch["value"] == " world" + + def test_str_ins_path_uses_correct_index_after_parts(self): + acc = MessageAccumulator() + acc.process(Part(text="first")) + acc.process(Part(text="second")) + acc.process("stream") # add patch at index 2 + result = acc.process("ing") # str_ins at index 2 + assert result.patch is not None + assert result.patch[0]["path"] == "/parts/2/text" + + def test_first_part_produces_replace_root_patch(self): + acc = MessageAccumulator() + result = acc.process(Part(text="hello")) + assert result.patch is not None + assert len(result.patch) == 1 + patch = result.patch[0] + assert patch["op"] == "replace" + assert patch["path"] == "" + value = patch["value"] + assert value["parts"] == [{"text": "hello"}] # type: ignore[index] + assert "message_id" in value # type: ignore[operator] + + def test_first_metadata_produces_replace_root_patch(self): + acc = MessageAccumulator() + result = acc.process(Metadata({"key": "val"})) + assert result.patch is not None + assert len(result.patch) == 1 + patch = result.patch[0] + assert patch["op"] == "replace" + assert patch["path"] == "" + value = 
patch["value"] + assert value["parts"] == [] # type: ignore[index] + assert value["metadata"] == {"key": "val"} # type: ignore[index] + assert "message_id" in value # type: ignore[operator] + + def test_second_metadata_produces_incremental_patches(self): + acc = MessageAccumulator() + acc.process(Metadata({"ext://test": [{"ref": "a"}]})) + result = acc.process(Metadata({"ext://test": [{"ref": "b"}]})) + assert result.patch is not None + # Should produce incremental add patch, not a full replace + assert len(result.patch) >= 1 + # The patches should target /metadata/... paths + for op in result.patch: + assert op["path"].startswith("/metadata/") + # Apply all patches to verify correctness + from kagenti_adk.server.jsonpatch_ext import ExtendedJsonPatch + draft = {"parts": [], "metadata": {"ext://test": [{"ref": "a"}]}} + draft = ExtendedJsonPatch(result.patch).apply(draft) + assert draft["metadata"]["ext://test"] == [{"ref": "a"}, {"ref": "b"}] + + def test_message_id_propagated_in_all_accumulated_results(self): + acc = MessageAccumulator() + r1 = acc.process("Hello") + assert r1.message_id is not None + r2 = acc.process(" world") + assert r2.message_id == r1.message_id + r3 = acc.process(Part(text="part")) + assert r3.message_id == r1.message_id + r4 = acc.process(Metadata({"k": "v"})) + assert r4.message_id == r1.message_id + + def test_passthrough_has_no_patch(self): + acc = MessageAccumulator() + result = acc.process(AgentMessage(text="hello")) + assert result.patch is None + + +# --- Flush tests --- + + +class TestFlush: + def test_flush_from_text_part_returns_message(self): + acc = MessageAccumulator() + acc.process("hello") + acc.process(" world") + msg = acc.flush() + assert msg is not None + assert msg.role == Role.ROLE_AGENT + assert msg.parts[0].text == "hello world" + + def test_flush_from_message_context_returns_message(self): + acc = MessageAccumulator() + acc.process(Part(text="a")) + acc.process(Part(text="b")) + msg = acc.flush() + assert msg is not 
None + assert len(msg.parts) == 2 + assert msg.parts[0].text == "a" + assert msg.parts[1].text == "b" + + def test_flush_from_base_level_returns_none(self): + acc = MessageAccumulator() + assert acc.flush() is None + + def test_flush_resets_to_base_level(self): + acc = MessageAccumulator() + acc.process("hello") + assert isinstance(acc.active_context, TextPartContext) + acc.flush() + assert isinstance(acc.active_context, MessageAccumulator) + + def test_flush_after_passthrough_returns_none(self): + acc = MessageAccumulator() + acc.process("hello") + acc.process(AgentMessage(text="explicit")) # flushes internally + assert isinstance(acc.active_context, MessageAccumulator) + assert acc.flush() is None + + def test_flush_with_metadata(self): + acc = MessageAccumulator() + acc.process(Part(text="content")) + acc.process(Metadata({"key": "value"})) + msg = acc.flush() + assert msg is not None + assert msg.parts[0].text == "content" + from google.protobuf.json_format import MessageToDict + + assert MessageToDict(msg.metadata) == {"key": "value"} + + def test_double_flush_returns_none(self): + acc = MessageAccumulator() + acc.process("hello") + msg = acc.flush() + assert msg is not None + assert acc.flush() is None diff --git a/apps/adk-py/tests/unit/server/test_context.py b/apps/adk-py/tests/unit/server/test_context.py new file mode 100644 index 00000000..b31a5966 --- /dev/null +++ b/apps/adk-py/tests/unit/server/test_context.py @@ -0,0 +1,188 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest +from a2a.types import Part, Role +from google.protobuf.json_format import MessageToDict + +from kagenti_adk.a2a.types import Metadata +from kagenti_adk.server.accumulator import MessageContext, TextPartContext + +pytestmark = pytest.mark.unit + + +# --- TextPartContext --- + + +class TestTextPartContext: + def _make(self, part_index: int = 0) -> TextPartContext: + return 
TextPartContext(message_context=MessageContext(), part_index=part_index) + + def test_first_chunk_produces_replace_root_patch(self): + ctx = self._make() + patches = ctx.add_chunk("Hello") + assert len(patches) == 1 + patch = patches[0] + assert patch["op"] == "replace" + assert patch["path"] == "" + value = patch["value"] # type: ignore[typeddict-item] + assert value["parts"] == [{"text": "Hello"}] # type: ignore[index] + assert "message_id" in value # type: ignore[operator] + + def test_subsequent_chunk_produces_str_ins(self): + ctx = self._make() + ctx.add_chunk("Hello") + patches = ctx.add_chunk(" world") + assert len(patches) == 1 + patch = patches[0] + assert patch["op"] == "str_ins" + assert patch["path"] == "/parts/0/text" + assert patch["value"] == " world" + assert patch["pos"] == 5 + + def test_pos_advances_by_chunk_length(self): + ctx = self._make() + assert ctx.pos == 0 + ctx.add_chunk("abc") + assert ctx.pos == 3 + ctx.add_chunk("de") + assert ctx.pos == 5 + ctx.add_chunk("f") + assert ctx.pos == 6 + + def test_build_concatenates_chunks(self): + ctx = self._make() + ctx.add_chunk("Hello") + ctx.add_chunk(" ") + ctx.add_chunk("world") + part = ctx.build() + assert part.text == "Hello world" + + def test_str_ins_uses_part_index(self): + ctx = self._make(part_index=3) + ctx.add_chunk("Hello") + patches = ctx.add_chunk(" world") + assert patches[0]["path"] == "/parts/3/text" + + def test_build_empty_chunks(self): + ctx = self._make() + part = ctx.build() + assert part.text == "" + + +# --- MessageContext --- + + +class TestMessageContext: + def test_first_add_part_returns_replace_root_patch(self): + ctx = MessageContext() + patches = ctx.add_part(Part(text="hello")) + assert len(patches) == 1 + patch = patches[0] + assert patch["op"] == "replace" + assert patch["path"] == "" + value = patch["value"] # type: ignore[typeddict-item] + assert value["parts"] == [{"text": "hello"}] # type: ignore[index] + assert "message_id" in value # type: ignore[operator] + 
+ def test_second_add_part_returns_add_patch(self): + ctx = MessageContext() + ctx.add_part(Part(text="first")) + patches = ctx.add_part(Part(text="second")) + assert len(patches) == 1 + patch = patches[0] + assert patch["op"] == "add" + assert patch["path"] == "/parts/-" + assert patch["value"]["text"] == "second" + + def test_add_part_accumulates(self): + ctx = MessageContext() + ctx.add_part(Part(text="a")) + ctx.add_part(Part(text="b")) + assert len(ctx.parts) == 2 + assert ctx.parts[0].text == "a" + assert ctx.parts[1].text == "b" + + def test_first_add_metadata_returns_replace_root_patch(self): + ctx = MessageContext() + patches = ctx.add_metadata(Metadata({"key": "value"})) + assert len(patches) == 1 + patch = patches[0] + assert patch["op"] == "replace" + assert patch["path"] == "" + value = patch["value"] # type: ignore[typeddict-item] + assert value["parts"] == [] # type: ignore[index] + assert value["metadata"] == {"key": "value"} # type: ignore[index] + assert "message_id" in value # type: ignore[operator] + + def test_add_metadata_after_part_returns_add_patch(self): + ctx = MessageContext() + ctx.add_part(Part(text="hello")) + patches = ctx.add_metadata(Metadata({"key": "value"})) + assert len(patches) == 1 + patch = patches[0] + assert patch["op"] == "add" + assert patch["path"] == "/metadata" + assert patch["value"] == {"key": "value"} + + def test_add_metadata_deep_merges(self): + ctx = MessageContext() + ctx.add_metadata(Metadata({"a": 1})) + ctx.add_metadata(Metadata({"b": 2})) + assert ctx.metadata == {"a": 1, "b": 2} + + def test_add_metadata_second_call_returns_incremental_patches(self): + ctx = MessageContext() + ctx.add_metadata(Metadata({"a": 1})) + patches = ctx.add_metadata(Metadata({"b": 2})) + assert len(patches) >= 1 + # Should produce an add for the new key, not a full replace + assert all(op["path"].startswith("/metadata/") for op in patches) + # Verify the patches are correct by applying them + from kagenti_adk.server.jsonpatch_ext 
import ExtendedJsonPatch + draft = {"metadata": {"a": 1}} + draft = ExtendedJsonPatch(patches).apply(draft) + assert draft["metadata"] == {"a": 1, "b": 2} + + def test_add_metadata_deep_merges_nested(self): + ctx = MessageContext() + ctx.add_metadata(Metadata({"ext": {"x": 1}})) + ctx.add_metadata(Metadata({"ext": {"y": 2}})) + assert ctx.metadata == {"ext": {"x": 1, "y": 2}} + + def test_add_metadata_concatenates_arrays(self): + """Extensions using arrays (e.g. citations, trajectory) should accumulate via concatenation.""" + uri = "https://example.com/ext/v1" + ctx = MessageContext() + ctx.add_metadata(Metadata({uri: [{"title": "a"}]})) + ctx.add_metadata(Metadata({uri: [{"title": "b"}]})) + assert ctx.metadata == {uri: [{"title": "a"}, {"title": "b"}]} + + def test_build_creates_agent_message(self): + ctx = MessageContext() + ctx.add_part(Part(text="hello")) + ctx.add_metadata(Metadata({"key": "value"})) + msg = ctx.build() + assert msg.role == Role.ROLE_AGENT + assert msg.message_id == str(ctx.message_id) + assert len(msg.parts) == 1 + assert msg.parts[0].text == "hello" + assert MessageToDict(msg.metadata) == {"key": "value"} + + def test_build_without_metadata(self): + ctx = MessageContext() + ctx.add_part(Part(text="hello")) + msg = ctx.build() + assert msg.role == Role.ROLE_AGENT + assert len(msg.parts) == 1 + assert not MessageToDict(msg.metadata) + + def test_build_with_multiple_parts(self): + ctx = MessageContext() + ctx.add_part(Part(text="a")) + ctx.add_part(Part(text="b")) + ctx.add_part(Part(text="c")) + msg = ctx.build() + assert [p.text for p in msg.parts] == ["a", "b", "c"] diff --git a/apps/adk-py/tests/unit/server/test_jsonpatch_ext.py b/apps/adk-py/tests/unit/server/test_jsonpatch_ext.py new file mode 100644 index 00000000..b8a81857 --- /dev/null +++ b/apps/adk-py/tests/unit/server/test_jsonpatch_ext.py @@ -0,0 +1,78 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC + +import jsonpatch +import pytest + +from 
kagenti_adk.server.jsonpatch_ext import StrInsOperation, make_patch + + +def test_str_ins_append(): + """Test optimized append detection""" + src = {"text": "Hello"} + dst = {"text": "Hello World"} + patch = make_patch(src, dst) + ops = list(patch) + assert len(ops) == 1 + assert ops[0]["op"] == "str_ins" + assert ops[0]["path"] == "/text" + assert ops[0]["pos"] == 5 + assert ops[0]["value"] == " World" + + # Test apply + assert patch.apply(src) == dst + + +def test_str_ins_middle(): + """Test middle insertion detection via difflib""" + src = {"text": "Hello World"} + dst = {"text": "Hello beautiful World"} + patch = make_patch(src, dst) + ops = list(patch) + assert len(ops) == 1 + assert ops[0]["op"] == "str_ins" + assert ops[0]["path"] == "/text" + assert ops[0]["value"] == "beautiful " + assert ops[0]["pos"] == 6 + + assert patch.apply(src) == dst + + +def test_str_ins_prepend(): + """Test prepend detection""" + src = {"text": "World"} + dst = {"text": "Hello World"} + patch = make_patch(src, dst) + ops = list(patch) + assert len(ops) == 1 + assert ops[0]["op"] == "str_ins" + assert ops[0]["pos"] == 0 + assert ops[0]["value"] == "Hello " + + assert patch.apply(src) == dst + + +def test_str_ins_complex_fallback(): + """Test that multiple changes fallback to replace""" + src = {"text": "Hello World"} + dst = {"text": "Hi World!"} # Changed Start and End + patch = make_patch(src, dst) + ops = list(patch) + assert len(ops) == 1 + assert ops[0]["op"] == "replace" + + assert patch.apply(src) == dst + + +def test_str_ins_explicit_apply(): + """Test manual StrInsOperation application""" + obj = {"foo": ["bar"]} + op = StrInsOperation({"op": "str_ins", "path": "/foo/0", "pos": 1, "value": "az"}) + res = op.apply(obj) + assert res["foo"][0] == "bazar" + + +def test_str_ins_out_of_bounds(): + obj = {"text": "foo"} + op = StrInsOperation({"op": "str_ins", "path": "/text", "pos": 100, "value": "bar"}) + with pytest.raises(jsonpatch.JsonPatchConflict): + op.apply(obj) diff 
--git a/apps/adk-py/tests/unit/test_agent_detail_population.py b/apps/adk-py/tests/unit/test_agent_detail_population.py index a01179bd..89cf8b36 100644 --- a/apps/adk-py/tests/unit/test_agent_detail_population.py +++ b/apps/adk-py/tests/unit/test_agent_detail_population.py @@ -42,8 +42,11 @@ def test_agent_fn(): # Check tools populated from skills assert "tools" in params tools = params["tools"] + # pyrefly: ignore [bad-argument-type] assert len(tools) == 1 + # pyrefly: ignore [unsupported-operation, bad-index] assert tools[0]["name"] == "test_skill" + # pyrefly: ignore [unsupported-operation, bad-index] assert tools[0]["description"] == "A test skill" # Check user_greeting populated from description @@ -88,7 +91,9 @@ def test_agent_fn(): # Tools should still be populated because custom_detail.tools was None assert "tools" in params tools = params["tools"] + # pyrefly: ignore [bad-argument-type] assert len(tools) == 1 + # pyrefly: ignore [unsupported-operation, bad-index] assert tools[0]["name"] == "test_skill" # Greeting should be custom diff --git a/apps/adk-py/tests/unit/test_dependencies.py b/apps/adk-py/tests/unit/test_dependencies.py index 53c353db..362b1075 100644 --- a/apps/adk-py/tests/unit/test_dependencies.py +++ b/apps/adk-py/tests/unit/test_dependencies.py @@ -7,18 +7,26 @@ import pytest -from kagenti_adk.a2a.extensions import CitationExtensionServer, CitationExtensionSpec +from kagenti_adk.a2a.extensions import ( + CitationExtensionServer, + CitationExtensionSpec, + TrajectoryExtensionServer, + TrajectoryExtensionSpec, +) +from kagenti_adk.a2a.extensions.streaming import StreamingExtensionServer, StreamingExtensionSpec from kagenti_adk.server.dependencies import extract_dependencies +pytestmark = pytest.mark.unit + class MyExtensions(TypedDict): a: Annotated[CitationExtensionServer, CitationExtensionSpec()] - b: Annotated[CitationExtensionServer, CitationExtensionSpec()] - c: Annotated[CitationExtensionServer, CitationExtensionSpec()] + b: 
Annotated[TrajectoryExtensionServer, TrajectoryExtensionSpec()] + c: Annotated[StreamingExtensionServer, StreamingExtensionSpec()] class MyExtensionsComplex(TypedDict): - b: Annotated[CitationExtensionServer, CitationExtensionSpec()] + b: Annotated[TrajectoryExtensionServer, TrajectoryExtensionSpec()] @pytest.mark.unit @@ -29,7 +37,6 @@ def agent(a: Annotated[CitationExtensionServer, CitationExtensionSpec()]) -> Non assert extract_dependencies(agent).keys() == {"a"} -@pytest.mark.unit def test_extract_dependencies_extra_parameters() -> None: def agent(a: Annotated[CitationExtensionServer, CitationExtensionSpec()], b: bool) -> None: pass @@ -48,8 +55,8 @@ def agent(**kwargs: Unpack[MyExtensions]) -> None: assert extract_dependencies(agent).keys() == {"a", "b", "c"} -@pytest.mark.unit def test_extract_dependencies_complex() -> None: + def agent( a: Annotated[CitationExtensionServer, CitationExtensionSpec()], **kwargs: Unpack[MyExtensionsComplex], diff --git a/apps/adk-py/uv.lock b/apps/adk-py/uv.lock index 98d1744e..18af3e37 100644 --- a/apps/adk-py/uv.lock +++ b/apps/adk-py/uv.lock @@ -1119,6 +1119,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/94/9e/820c4b086ad01ba7d77369fb8b11470a01fac9b4977f02e18659cf378b6b/json_rpc-1.15.0-py2.py3-none-any.whl", hash = "sha256:4a4668bbbe7116feb4abbd0f54e64a4adcf4b8f648f19ffa0848ad0f6606a9bf", size = 39450, upload-time = "2023-06-11T09:45:47.136Z" }, ] +[[package]] +name = "jsonpatch" +version = "1.33" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jsonpointer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/78/18813351fe5d63acad16aec57f94ec2b70a09e53ca98145589e185423873/jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c", size = 21699, upload-time = "2023-06-26T12:07:29.144Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/73/07/02e16ed01e04a374e644b575638ec7987ae846d25ad97bcc9945a3ee4b0e/jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade", size = 12898, upload-time = "2023-06-16T21:01:28.466Z" }, +] + +[[package]] +name = "jsonpointer" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/0a/eebeb1fa92507ea94016a2a790b93c2ae41a7e18778f85471dc54475ed25/jsonpointer-3.0.0.tar.gz", hash = "sha256:2b2d729f2091522d61c3b31f82e11870f60b68f43fbc705cb76bf4b832af59ef", size = 9114, upload-time = "2024-06-10T19:24:42.462Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/71/92/5e77f98553e9e75130c78900d000368476aed74276eb8ae8796f65f00918/jsonpointer-3.0.0-py2.py3-none-any.whl", hash = "sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942", size = 7595, upload-time = "2024-06-10T19:24:40.698Z" }, +] + [[package]] name = "jsonref" version = "1.1.0" @@ -1162,6 +1183,7 @@ source = { editable = "." 
} dependencies = [ { name = "a2a-sdk", extra = ["sqlite"] }, { name = "anyio" }, + { name = "asgiref" }, { name = "async-lru" }, { name = "asyncclick" }, { name = "authlib" }, @@ -1169,6 +1191,7 @@ dependencies = [ { name = "fastapi" }, { name = "httpx" }, { name = "janus" }, + { name = "jsonpatch" }, { name = "mcp" }, { name = "objprint" }, { name = "opentelemetry-api" }, @@ -1198,6 +1221,7 @@ dev = [ requires-dist = [ { name = "a2a-sdk", extras = ["sqlite"], specifier = "==1.0.0a0" }, { name = "anyio", specifier = ">=4.9.0" }, + { name = "asgiref", specifier = ">=3.11.0" }, { name = "async-lru", specifier = ">=2.0.4" }, { name = "asyncclick", specifier = ">=8.1.8" }, { name = "authlib", specifier = ">=1.3.0" }, @@ -1205,6 +1229,7 @@ requires-dist = [ { name = "fastapi", specifier = ">=0.116.1" }, { name = "httpx" }, { name = "janus", specifier = ">=2.0.0" }, + { name = "jsonpatch", specifier = ">=1.33" }, { name = "mcp", specifier = ">=1.12.3" }, { name = "objprint", specifier = ">=0.3.0" }, { name = "opentelemetry-api", specifier = ">=1.35.0" }, diff --git a/apps/adk-server/pyproject.toml b/apps/adk-server/pyproject.toml index d813fcbb..8d03076a 100644 --- a/apps/adk-server/pyproject.toml +++ b/apps/adk-server/pyproject.toml @@ -52,6 +52,7 @@ dependencies = [ "opentelemetry-instrumentation-openai>=0.52.1", "opentelemetry-instrumentation-sqlalchemy>=0.60b1", "aws-bedrock-token-generator>=1.1.0", + "jsonpatch>=1.33", ] [dependency-groups] diff --git a/apps/adk-server/src/adk_server/api/routes/a2a.py b/apps/adk-server/src/adk_server/api/routes/a2a.py index cb390345..b14d15c5 100644 --- a/apps/adk-server/src/adk_server/api/routes/a2a.py +++ b/apps/adk-server/src/adk_server/api/routes/a2a.py @@ -100,7 +100,7 @@ async def get_agent_card( card_copy = create_proxy_agent_card( provider.agent_card, provider_id=provider.id, request=request, configuration=configuration ) - return MessageToDict(card_copy, preserving_proto_field_name=True) + return MessageToDict(card_copy) 
@router.post("/{provider_id}") diff --git a/apps/adk-server/tests/e2e/agents/test_platform_extensions.py b/apps/adk-server/tests/e2e/agents/test_platform_extensions.py index 54bbb147..ceca3e5c 100644 --- a/apps/adk-server/tests/e2e/agents/test_platform_extensions.py +++ b/apps/adk-server/tests/e2e/agents/test_platform_extensions.py @@ -8,8 +8,8 @@ from uuid import uuid4 import pytest -from a2a.client import Client, create_text_message_object -from a2a.types import SendMessageRequest, Message, Role, TaskState +from a2a.client import Client +from a2a.types import Message, Role, SendMessageRequest, TaskState from kagenti_adk.a2a.extensions.services.platform import ( PlatformApiExtensionClient, PlatformApiExtensionServer, @@ -126,4 +126,3 @@ async def test_self_registration(self_registration_agent, subtests): assert provider.state == "online" assert "self_registration_agent" in provider.source - diff --git a/apps/adk-server/uv.lock b/apps/adk-server/uv.lock index 970842d6..dd304e58 100644 --- a/apps/adk-server/uv.lock +++ b/apps/adk-server/uv.lock @@ -46,6 +46,7 @@ dependencies = [ { name = "httpx" }, { name = "ibm-watsonx-ai" }, { name = "ijson" }, + { name = "jsonpatch" }, { name = "kink" }, { name = "kr8s" }, { name = "limits", extra = ["async-redis"] }, @@ -106,6 +107,7 @@ requires-dist = [ { name = "httpx", specifier = ">=0.28.1" }, { name = "ibm-watsonx-ai", specifier = ">=1.3.28" }, { name = "ijson", specifier = ">=3.4.0.post0" }, + { name = "jsonpatch", specifier = ">=1.33" }, { name = "kink", specifier = ">=0.8.1" }, { name = "kr8s", specifier = ">=0.20.7" }, { name = "limits", extras = ["async-redis"], specifier = ">=5.3.0" }, @@ -1203,6 +1205,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/94/9e/820c4b086ad01ba7d77369fb8b11470a01fac9b4977f02e18659cf378b6b/json_rpc-1.15.0-py2.py3-none-any.whl", hash = "sha256:4a4668bbbe7116feb4abbd0f54e64a4adcf4b8f648f19ffa0848ad0f6606a9bf", size = 39450, upload-time = "2023-06-11T09:45:47.136Z" }, ] 
+[[package]] +name = "jsonpatch" +version = "1.33" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jsonpointer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/78/18813351fe5d63acad16aec57f94ec2b70a09e53ca98145589e185423873/jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c", size = 21699, upload-time = "2023-06-26T12:07:29.144Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/07/02e16ed01e04a374e644b575638ec7987ae846d25ad97bcc9945a3ee4b0e/jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade", size = 12898, upload-time = "2023-06-16T21:01:28.466Z" }, +] + +[[package]] +name = "jsonpointer" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/0a/eebeb1fa92507ea94016a2a790b93c2ae41a7e18778f85471dc54475ed25/jsonpointer-3.0.0.tar.gz", hash = "sha256:2b2d729f2091522d61c3b31f82e11870f60b68f43fbc705cb76bf4b832af59ef", size = 9114, upload-time = "2024-06-10T19:24:42.462Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/71/92/5e77f98553e9e75130c78900d000368476aed74276eb8ae8796f65f00918/jsonpointer-3.0.0-py2.py3-none-any.whl", hash = "sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942", size = 7595, upload-time = "2024-06-10T19:24:40.698Z" }, +] + [[package]] name = "jsonschema" version = "4.26.0" @@ -1250,6 +1273,7 @@ source = { editable = "../adk-py" } dependencies = [ { name = "a2a-sdk", extra = ["sqlite"] }, { name = "anyio" }, + { name = "asgiref" }, { name = "async-lru" }, { name = "asyncclick" }, { name = "authlib" }, @@ -1257,6 +1281,7 @@ dependencies = [ { name = "fastapi" }, { name = "httpx" }, { name = "janus" }, + { name = "jsonpatch" }, { name = "mcp" }, { name = "objprint" }, { name = "opentelemetry-api" }, @@ -1276,6 +1301,7 @@ dependencies = [ requires-dist = 
[ { name = "a2a-sdk", extras = ["sqlite"], specifier = "==1.0.0a0" }, { name = "anyio", specifier = ">=4.9.0" }, + { name = "asgiref", specifier = ">=3.11.0" }, { name = "async-lru", specifier = ">=2.0.4" }, { name = "asyncclick", specifier = ">=8.1.8" }, { name = "authlib", specifier = ">=1.3.0" }, @@ -1283,6 +1309,7 @@ requires-dist = [ { name = "fastapi", specifier = ">=0.116.1" }, { name = "httpx" }, { name = "janus", specifier = ">=2.0.0" }, + { name = "jsonpatch", specifier = ">=1.33" }, { name = "mcp", specifier = ">=1.12.3" }, { name = "objprint", specifier = ">=0.3.0" }, { name = "opentelemetry-api", specifier = ">=1.35.0" }, diff --git a/apps/adk-ts/src/client/a2a/extensions/index.ts b/apps/adk-ts/src/client/a2a/extensions/index.ts index dc565d55..e0f4cfda 100644 --- a/apps/adk-ts/src/client/a2a/extensions/index.ts +++ b/apps/adk-ts/src/client/a2a/extensions/index.ts @@ -19,4 +19,5 @@ export * from './ui/citation'; export * from './ui/error'; export * from './ui/form-request'; export * from './ui/settings'; +export * from './ui/streaming'; export * from './ui/trajectory'; diff --git a/apps/adk-ts/src/client/a2a/extensions/types.ts b/apps/adk-ts/src/client/a2a/extensions/types.ts index 730bd2e8..36713071 100644 --- a/apps/adk-ts/src/client/a2a/extensions/types.ts +++ b/apps/adk-ts/src/client/a2a/extensions/types.ts @@ -17,4 +17,5 @@ export * from './ui/canvas/types'; export * from './ui/citation/types'; export * from './ui/error/types'; export * from './ui/settings/types'; +export * from './ui/streaming/types'; export * from './ui/trajectory/types'; diff --git a/apps/adk-ts/src/client/a2a/extensions/ui/citation/index.ts b/apps/adk-ts/src/client/a2a/extensions/ui/citation/index.ts index c18c34de..e88140a7 100644 --- a/apps/adk-ts/src/client/a2a/extensions/ui/citation/index.ts +++ b/apps/adk-ts/src/client/a2a/extensions/ui/citation/index.ts @@ -6,12 +6,12 @@ import z from 'zod'; import type { A2AUiExtension } from '../../../../core/extensions/types'; -import 
{ citationMetadataSchema } from './schemas'; -import type { CitationMetadata } from './types'; +import { citationSchema } from './schemas'; +import type { Citation } from './types'; export const CITATION_EXTENSION_URI = 'https://a2a-extensions.adk.kagenti.dev/ui/citation/v1'; -export const citationExtension: A2AUiExtension = { +export const citationExtension: A2AUiExtension = { getUri: () => CITATION_EXTENSION_URI, - getMessageMetadataSchema: () => z.object({ [CITATION_EXTENSION_URI]: citationMetadataSchema }).partial(), + getMessageMetadataSchema: () => z.object({ [CITATION_EXTENSION_URI]: z.array(citationSchema) }).partial(), }; diff --git a/apps/adk-ts/src/client/a2a/extensions/ui/streaming/index.ts b/apps/adk-ts/src/client/a2a/extensions/ui/streaming/index.ts new file mode 100644 index 00000000..d1186faf --- /dev/null +++ b/apps/adk-ts/src/client/a2a/extensions/ui/streaming/index.ts @@ -0,0 +1,20 @@ +/** + * Copyright 2025 © BeeAI a Series of LF Projects, LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import z from 'zod'; + +import type { A2AUiExtension } from '../../../../core/extensions/types'; +import { streamingMetadataSchema } from './schemas'; +import type { StreamingMetadata } from './types'; + +export type { StreamingMetadata, StreamingPatch } from './types'; +export { streamingMetadataSchema, streamingPatchSchema } from './schemas'; + +export const STREAMING_EXTENSION_URI = 'https://a2a-extensions.agentstack.beeai.dev/ui/streaming/v1'; + +export const streamingExtension: A2AUiExtension = { + getUri: () => STREAMING_EXTENSION_URI, + getMessageMetadataSchema: () => z.object({ [STREAMING_EXTENSION_URI]: streamingMetadataSchema }).partial(), +}; diff --git a/apps/adk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts b/apps/adk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts new file mode 100644 index 00000000..3cce4408 --- /dev/null +++ b/apps/adk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts @@ -0,0 +1,18 @@ +/** + * Copyright 2025 © 
BeeAI a Series of LF Projects, LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import z from 'zod'; + +export const streamingPatchSchema = z.object({ + op: z.string(), + path: z.string(), + value: z.unknown().optional(), + pos: z.number().optional(), // for str_ins +}); + +export const streamingMetadataSchema = z.object({ + message_update: z.array(streamingPatchSchema).optional(), + message_id: z.string().optional(), +}); diff --git a/apps/adk-ts/src/client/a2a/extensions/ui/streaming/types.ts b/apps/adk-ts/src/client/a2a/extensions/ui/streaming/types.ts new file mode 100644 index 00000000..aa10398f --- /dev/null +++ b/apps/adk-ts/src/client/a2a/extensions/ui/streaming/types.ts @@ -0,0 +1,12 @@ +/** + * Copyright 2025 © BeeAI a Series of LF Projects, LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type z from 'zod'; + +import type { streamingMetadataSchema, streamingPatchSchema } from './schemas'; + +export type StreamingPatch = z.infer; + +export type StreamingMetadata = z.infer; diff --git a/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/index.ts b/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/index.ts index d18b230d..7ab8aac8 100644 --- a/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/index.ts +++ b/apps/adk-ts/src/client/a2a/extensions/ui/trajectory/index.ts @@ -11,7 +11,7 @@ import type { TrajectoryMetadata } from './types'; export const TRAJECTORY_EXTENSION_URI = 'https://a2a-extensions.adk.kagenti.dev/ui/trajectory/v1'; -export const trajectoryExtension: A2AUiExtension = { +export const trajectoryExtension: A2AUiExtension = { getUri: () => TRAJECTORY_EXTENSION_URI, - getMessageMetadataSchema: () => z.object({ [TRAJECTORY_EXTENSION_URI]: trajectoryMetadataSchema }).partial(), + getMessageMetadataSchema: () => z.object({ [TRAJECTORY_EXTENSION_URI]: z.array(trajectoryMetadataSchema) }).partial(), }; diff --git a/apps/adk-ui/src/api/a2a/agent-card.ts b/apps/adk-ui/src/api/a2a/agent-card.ts index 5bbc46ac..2ab80317 100644 --- 
a/apps/adk-ui/src/api/a2a/agent-card.ts +++ b/apps/adk-ui/src/api/a2a/agent-card.ts @@ -19,7 +19,8 @@ export async function getAgentClient(providerId: string, token: string): Promise const agentCard = await fetchAgentCard(agentCardUrl, fetchImpl); - return createA2AClient({ endpointUrl, agentCard, fetchImpl }); + const extensions = agentCard.capabilities?.extensions?.map((ext) => ext.uri).filter(Boolean) as string[]; + return createA2AClient({ endpointUrl, agentCard, fetchImpl, extensions }); } async function clientFetch(input: RequestInfo, init?: RequestInit) { diff --git a/apps/adk-ui/src/api/a2a/client.ts b/apps/adk-ui/src/api/a2a/client.ts index fbe1b796..0b3afe09 100644 --- a/apps/adk-ui/src/api/a2a/client.ts +++ b/apps/adk-ui/src/api/a2a/client.ts @@ -17,6 +17,7 @@ import { getAgentClient } from './agent-card'; import { AGENT_ERROR_MESSAGE } from './constants'; import type { A2AClient } from './jsonrpc-client'; import { processArtifactMetadata, processMessageMetadata, processParts } from './part-processors'; +import { applyPatches, extractStreamingPatches } from './streaming'; import type { ChatResult, TaskStatusUpdateResultWithTaskId } from './types'; import { type ChatParams, type ChatRun, RunResultType } from './types'; import { createUserMessage, extractErrorExtension } from './utils'; @@ -124,6 +125,7 @@ export const buildA2AClient = async ({ const messageSubject = new Subject>(); let taskId: undefined | TaskId = initialTaskId; + const streamingDraft: Record = {}; const iterateOverStream = async () => { const agentCardMetadata = await resolveAgentCardMetadata(fulfillments); @@ -153,6 +155,28 @@ export const buildA2AClient = async ({ .with({ statusUpdate: P.nonNullable }, ({ statusUpdate }) => { taskId = statusUpdate.taskId; + // Check for streaming patches in metadata + const patches = extractStreamingPatches( + statusUpdate.metadata as Record | undefined, + ); + + if (patches && taskId) { + // Apply patches to draft and emit as a replace update + 
applyPatches(streamingDraft, patches); + const draftParts = (streamingDraft.parts as Array>) ?? []; + const uiParts: UIMessagePart[] = draftParts + .map((part): UIMessagePart | null => { + if (typeof part.text === 'string') { + return { kind: UIMessagePartKind.Text, id: 'streaming-text', text: part.text } as UITextPart; + } + return null; + }) + .filter((p): p is UIMessagePart => p !== null); + + messageSubject.next({ type: RunResultType.Parts, parts: uiParts, taskId, replace: true }); + return; + } + handleTaskStatusUpdate(statusUpdate).forEach((result) => { if (!taskId) { throw new Error(`Illegal State - taskId missing on status-update event`); @@ -166,10 +190,17 @@ export const buildA2AClient = async ({ const parts: (UIMessagePart | UIGenericPart)[] = handleStatusUpdate(statusUpdate, onStatusUpdate); + // On final message, clear the streaming draft and replace the + // streamed text so it is not duplicated by the complete message. + const wasStreaming = Object.keys(streamingDraft).length > 0; + if (wasStreaming) { + Object.keys(streamingDraft).forEach((key) => delete streamingDraft[key]); + } + if (!taskId) { throw new Error(`Illegal State - taskId missing on status-update event`); } - messageSubject.next({ type: RunResultType.Parts, parts, taskId }); + messageSubject.next({ type: RunResultType.Parts, parts, taskId, replace: wasStreaming }); }) .with({ artifactUpdate: P.nonNullable }, ({ artifactUpdate }) => { taskId = artifactUpdate.taskId; diff --git a/apps/adk-ui/src/api/a2a/jsonrpc-client.ts b/apps/adk-ui/src/api/a2a/jsonrpc-client.ts index 5c901b4c..9083bd49 100644 --- a/apps/adk-ui/src/api/a2a/jsonrpc-client.ts +++ b/apps/adk-ui/src/api/a2a/jsonrpc-client.ts @@ -23,13 +23,19 @@ interface CreateClientParams { endpointUrl: string; agentCard: AgentCard; fetchImpl: typeof fetch; + extensions?: string[]; } -export function createA2AClient({ endpointUrl, agentCard, fetchImpl }: CreateClientParams): A2AClient { +export function createA2AClient({ endpointUrl, 
agentCard, fetchImpl, extensions }: CreateClientParams): A2AClient { + const requestHeaders: Record = { 'Content-Type': 'application/json' }; + if (extensions?.length) { + requestHeaders['X-A2A-Extensions'] = extensions.join(','); + } + async function jsonRpcRequest(method: string, params: Record) { const response = await fetchImpl(endpointUrl, { method: 'POST', - headers: { 'Content-Type': 'application/json' }, + headers: requestHeaders, body: JSON.stringify({ jsonrpc: '2.0', id: uuid(), @@ -59,7 +65,7 @@ export function createA2AClient({ endpointUrl, agentCard, fetchImpl }: CreateCli async *sendMessageStream(params) { const response = await fetchImpl(endpointUrl, { method: 'POST', - headers: { 'Content-Type': 'application/json' }, + headers: requestHeaders, body: JSON.stringify({ jsonrpc: '2.0', id: uuid(), diff --git a/apps/adk-ui/src/api/a2a/part-processors.ts b/apps/adk-ui/src/api/a2a/part-processors.ts index f04aa591..e181cf59 100644 --- a/apps/adk-ui/src/api/a2a/part-processors.ts +++ b/apps/adk-ui/src/api/a2a/part-processors.ts @@ -21,12 +21,14 @@ import { export function processMessageMetadata(message: Message): UIMessagePart[] { const trajectory = extractTrajectory(message.metadata); - const citations = extractCitation(message.metadata)?.citations; + const citations = extractCitation(message.metadata); const parts: UIMessagePart[] = []; if (trajectory) { - parts.push(createTrajectoryPart(trajectory)); + for (const item of trajectory) { + parts.push(createTrajectoryPart(item)); + } } if (citations) { const sourceParts = citations.map((citation) => createSourcePart(citation, message.taskId)).filter(isNotNull); @@ -38,7 +40,7 @@ export function processMessageMetadata(message: Message): UIMessagePart[] { } export function processArtifactMetadata(artifact: Artifact, taskId: string): UISourcePart[] { - const citations = extractCitation(artifact.metadata)?.citations; + const citations = extractCitation(artifact.metadata); if (!citations) { return []; diff --git 
a/apps/adk-ui/src/api/a2a/streaming.ts b/apps/adk-ui/src/api/a2a/streaming.ts new file mode 100644 index 00000000..bdf03d1d --- /dev/null +++ b/apps/adk-ui/src/api/a2a/streaming.ts @@ -0,0 +1,136 @@ +/** + * Copyright 2025 © BeeAI a Series of LF Projects, LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { StreamingMetadata, StreamingPatch } from '@kagenti/adk'; +import { STREAMING_EXTENSION_URI } from '@kagenti/adk'; + +/** + * Extract streaming patches from status update metadata. + */ +export function extractStreamingPatches( + metadata: Record | undefined | null, +): StreamingPatch[] | null { + if (!metadata) return null; + const streamingData = metadata[STREAMING_EXTENSION_URI] as StreamingMetadata | undefined; + if (!streamingData?.message_update?.length) return null; + return streamingData.message_update; +} + +/** + * Resolve a JSON pointer path to get/set a value in a nested object. + * Returns [parent, key] for the target location. + */ +function resolvePath(obj: Record, path: string): [Record, string] { + if (path === '' || path === '/') return [obj, '']; + + const parts = path.split('/').filter(Boolean); + let current: unknown = obj; + + for (let i = 0; i < parts.length - 1; i++) { + const key = parts[i]; + if (Array.isArray(current)) { + current = current[Number(key)]; + } else if (current && typeof current === 'object') { + current = (current as Record)[key]; + } + } + + return [current as Record, parts[parts.length - 1]]; +} + +function getByPath(obj: Record, path: string): unknown { + if (path === '' || path === '/') return obj; + const parts = path.split('/').filter(Boolean); + let current: unknown = obj; + for (const key of parts) { + if (Array.isArray(current)) { + current = current[Number(key)]; + } else if (current && typeof current === 'object') { + current = (current as Record)[key]; + } else { + return undefined; + } + } + return current; +} + +/** Clone only objects; primitives (strings, numbers, booleans) are immutable and 
don't need cloning. */ +function cloneValue(v: T): T { + return typeof v === 'object' && v !== null ? structuredClone(v) : v; +} + +/** + * Apply a single streaming patch to a draft message object. + * Supports: replace, add, str_ins (custom string insertion). + */ +function applyPatch(draft: Record, patch: StreamingPatch): void { + const { op, path, value } = patch; + + if (op === 'replace') { + if (path === '' || path === '/') { + // Root replace — overwrite entire draft + Object.keys(draft).forEach((key) => delete draft[key]); + if (value && typeof value === 'object') { + Object.assign(draft, cloneValue(value)); + } + } else { + const [parent, key] = resolvePath(draft, path); + if (parent && typeof parent === 'object') { + if (Array.isArray(parent)) { + parent[Number(key)] = cloneValue(value); + } else { + parent[key] = cloneValue(value); + } + } + } + } else if (op === 'add') { + if (path.endsWith('/-')) { + // Append to array + const arrayPath = path.slice(0, -2); + const arr = getByPath(draft, arrayPath); + if (Array.isArray(arr)) { + arr.push(cloneValue(value)); + } + } else { + const [parent, key] = resolvePath(draft, path); + if (parent && typeof parent === 'object') { + if (Array.isArray(parent)) { + parent.splice(Number(key), 0, cloneValue(value)); + } else { + parent[key] = cloneValue(value); + } + } + } + } else if (op === 'str_ins') { + // Custom operation: insert string at position + const pos = patch.pos ?? 0; + const current = getByPath(draft, path); + if (typeof current === 'string' && typeof value === 'string') { + const newValue = current.slice(0, pos) + value + current.slice(pos); + const [parent, key] = resolvePath(draft, path); + if (parent && typeof parent === 'object') { + if (Array.isArray(parent)) { + parent[Number(key)] = newValue; + } else { + parent[key] = newValue; + } + } + } + } +} + +/** + * Apply an array of streaming patches to a draft message object. + * Mutates the draft in place and returns it. 
+ */ +export function applyPatches( + draft: Record, + patches: StreamingPatch[], +): Record { + for (const patch of patches) { + applyPatch(draft, patch); + } + return draft; +} diff --git a/apps/adk-ui/src/api/a2a/types.ts b/apps/adk-ui/src/api/a2a/types.ts index 89aa5dff..d7077eb1 100644 --- a/apps/adk-ui/src/api/a2a/types.ts +++ b/apps/adk-ui/src/api/a2a/types.ts @@ -18,6 +18,8 @@ export interface PartsResult { type: RunResultType.Parts; taskId: TaskId; parts: Array; + /** When true, replace all current parts instead of appending (used for streaming updates). */ + replace?: boolean; } export type TaskStatusUpdateResultWithTaskId = TaskStatusUpdateResult & { @@ -37,7 +39,9 @@ export interface ChatParams { export interface ChatRun { taskId?: TaskId; done: Promise; - subscribe: (fn: (data: { parts: (UIMessagePart | UIGenericPart)[]; taskId: TaskId }) => void) => () => void; + subscribe: ( + fn: (data: { parts: (UIMessagePart | UIGenericPart)[]; taskId: TaskId; replace?: boolean }) => void, + ) => () => void; cancel: () => Promise; } diff --git a/apps/adk-ui/src/api/a2a/utils.ts b/apps/adk-ui/src/api/a2a/utils.ts index 4e520ddf..303ee976 100644 --- a/apps/adk-ui/src/api/a2a/utils.ts +++ b/apps/adk-ui/src/api/a2a/utils.ts @@ -9,6 +9,7 @@ import { citationExtension, errorExtension, extractUiExtensionData, + streamingExtension, trajectoryExtension, type TrajectoryMetadata, } from '@kagenti/adk'; @@ -32,6 +33,7 @@ import { PLATFORM_FILE_CONTENT_URL_BASE } from './constants'; export const extractCitation = extractUiExtensionData(citationExtension); export const extractTrajectory = extractUiExtensionData(trajectoryExtension); export const extractErrorExtension = extractUiExtensionData(errorExtension); +export const extractStreamingMessage = extractUiExtensionData(streamingExtension); export function convertMessageParts(uiParts: UIMessagePart[]): Part[] { const parts: Part[] = uiParts diff --git a/apps/adk-ui/src/api/adk-client.ts b/apps/adk-ui/src/api/adk-client.ts index 
750fc15f..82792178 100644 --- a/apps/adk-ui/src/api/adk-client.ts +++ b/apps/adk-ui/src/api/adk-client.ts @@ -9,8 +9,6 @@ import { ensureToken } from '#app/(auth)/rsc.tsx'; import { runtimeConfig } from '#contexts/App/runtime-config.ts'; import { getBaseUrl } from '#utils/api/getBaseUrl.ts'; -import { getProxyHeaders } from './utils'; - function buildAuthenticatedAdkClient() { const { isAuthEnabled } = runtimeConfig; const baseUrl = getBaseUrl(); @@ -26,19 +24,6 @@ function buildAuthenticatedAdkClient() { } } - const isServer = typeof window === 'undefined'; - - if (isServer) { - const { headers } = await import('next/headers'); - const { forwarded, forwardedHost, forwardedFor, forwardedProto } = await getProxyHeaders(await headers()); - - request.headers.set('forwarded', forwarded); - - if (forwardedHost) request.headers.set('x-forwarded-host', forwardedHost); - if (forwardedProto) request.headers.set('x-forwarded-proto', forwardedProto); - if (forwardedFor) request.headers.set('x-forwarded-for', forwardedFor); - } - const response = await fetch(request); return response; diff --git a/apps/adk-ui/src/app/(auth)/auth.ts b/apps/adk-ui/src/app/(auth)/auth.ts index a464dbf6..b8755bc2 100644 --- a/apps/adk-ui/src/app/(auth)/auth.ts +++ b/apps/adk-ui/src/app/(auth)/auth.ts @@ -60,7 +60,7 @@ export function getProvider(): ProviderWithId | null { const issuer = process.env.OIDC_PROVIDER_ISSUER; const externalIssuer = process.env.OIDC_PROVIDER_EXTERNAL_ISSUER; - if (!name || !id || !clientId || !clientSecret || !issuer) { + if (!name || !id || !clientId || !issuer) { throw new Error( 'Missing OIDC provider configuration. 
Set OIDC_PROVIDER_NAME, OIDC_PROVIDER_ID, OIDC_PROVIDER_CLIENT_ID, OIDC_PROVIDER_CLIENT_SECRET, and OIDC_PROVIDER_ISSUER.', ); diff --git a/apps/adk-ui/src/modules/platform-context/constants.ts b/apps/adk-ui/src/modules/platform-context/constants.ts index abe8243a..15364f35 100644 --- a/apps/adk-ui/src/modules/platform-context/constants.ts +++ b/apps/adk-ui/src/modules/platform-context/constants.ts @@ -16,7 +16,7 @@ export const contextTokenPermissionsDefaults: DeepRequired) { pendingRun.current = run; let isFirstIteration = true; - pendingSubscription.current = run.subscribe(({ parts, taskId: responseTaskId }) => { + pendingSubscription.current = run.subscribe(({ parts, taskId: responseTaskId, replace }) => { if (isFirstIteration) { queryClient.invalidateQueries({ queryKey: contextKeys.lists() }); } updateCurrentAgentMessage((message) => { message.taskId = responseTaskId; - }); - parts.forEach((part) => { - updateCurrentAgentMessage((message) => { - const updatedParts = addMessagePart(part, message); - message.parts = updatedParts; - }); + if (replace) { + // Streaming update: replace text parts with latest draft + const nonTextParts = message.parts.filter((p) => p.kind !== UIMessagePartKind.Text); + message.parts = [...parts, ...nonTextParts]; + } else { + for (const part of parts) { + message.parts = addMessagePart(part, message); + } + } }); isFirstIteration = false; diff --git a/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/index.ts b/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/index.ts new file mode 100644 index 00000000..d1186faf --- /dev/null +++ b/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/index.ts @@ -0,0 +1,20 @@ +/** + * Copyright 2025 © BeeAI a Series of LF Projects, LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import z from 'zod'; + +import type { A2AUiExtension } from '../../../../core/extensions/types'; +import { streamingMetadataSchema } from './schemas'; +import type { StreamingMetadata } 
from './types'; + +export type { StreamingMetadata, StreamingPatch } from './types'; +export { streamingMetadataSchema, streamingPatchSchema } from './schemas'; + +export const STREAMING_EXTENSION_URI = 'https://a2a-extensions.agentstack.beeai.dev/ui/streaming/v1'; + +export const streamingExtension: A2AUiExtension = { + getUri: () => STREAMING_EXTENSION_URI, + getMessageMetadataSchema: () => z.object({ [STREAMING_EXTENSION_URI]: streamingMetadataSchema }).partial(), +}; diff --git a/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts b/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts new file mode 100644 index 00000000..3cce4408 --- /dev/null +++ b/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/schemas.ts @@ -0,0 +1,18 @@ +/** + * Copyright 2025 © BeeAI a Series of LF Projects, LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import z from 'zod'; + +export const streamingPatchSchema = z.object({ + op: z.string(), + path: z.string(), + value: z.unknown().optional(), + pos: z.number().optional(), // for str_ins +}); + +export const streamingMetadataSchema = z.object({ + message_update: z.array(streamingPatchSchema).optional(), + message_id: z.string().optional(), +}); diff --git a/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/types.ts b/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/types.ts new file mode 100644 index 00000000..aa10398f --- /dev/null +++ b/apps/agentstack-sdk-ts/src/client/a2a/extensions/ui/streaming/types.ts @@ -0,0 +1,12 @@ +/** + * Copyright 2025 © BeeAI a Series of LF Projects, LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type z from 'zod'; + +import type { streamingMetadataSchema, streamingPatchSchema } from './schemas'; + +export type StreamingPatch = z.infer; + +export type StreamingMetadata = z.infer; From ab4cbf243af4169654dcf0ab688237d1be3cc604 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Radek=20Je=C5=BEek?= Date: Thu, 26 Mar 2026 
16:02:16 +0100 Subject: [PATCH 2/2] feat: replace custom ContextStore with A2A-native task history MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use A2A Tasks as the single source of truth for conversation history, eliminating the duplicate ContextStore abstraction. adk-server: - Implement on_list_tasks in A2A proxy (forwards to agent transport) - Remove context history endpoints (POST/GET/DELETE /contexts/{id}/history) - Remove title generation job and worker - Remove history-related service, repository, and schema code adk-py: - Delete ContextStore, InMemoryContextStore, PlatformContextStore - Rewire RunContext to use TaskStore (load_history reads from task store) - Remove store(), store_sync(), delete_history_from_id() from RunContext - Remove context_store parameter from create_app() and Server.serve() - Remove ContextHistoryItem and history methods from platform SDK adk-ui: - Add listTasks to A2A JSON-RPC client - Create fetchTasksForContext for server-side task fetching - Add convertTasksToUIMessages for task-to-UI conversion - Rewire AgentRun/PlatformContextProvider/MessagesProvider to use tasks - Remove useListContextHistory, context history API functions and types Closes #116 Assisted-By: Claude (Anthropic AI) Signed-off-by: Radek Ježek --- agents/canvas/src/canvas/agent.py | 7 - agents/chat/src/chat/agent.py | 22 +-- .../src/content_builder/agent.py | 4 - agents/rag/src/rag/agent.py | 5 - apps/adk-py/examples/canvas_ui_code_agent.py | 6 - apps/adk-py/examples/canvas_ui_test_agent.py | 6 - apps/adk-py/examples/citation_agent.py | 7 +- .../examples/citation_agent_artifact.py | 8 +- apps/adk-py/examples/history.py | 6 - apps/adk-py/examples/history_framework.py | 5 - apps/adk-py/examples/oauth.py | 5 - apps/adk-py/examples/trajectory_agent.py | 20 +-- .../src/kagenti_adk/platform/context.py | 94 +----------- apps/adk-py/src/kagenti_adk/server/agent.py | 14 +- apps/adk-py/src/kagenti_adk/server/app.py | 7 +- 
apps/adk-py/src/kagenti_adk/server/context.py | 60 +++----- apps/adk-py/src/kagenti_adk/server/server.py | 4 - .../kagenti_adk/server/store/context_store.py | 36 ----- .../server/store/memory_context_store.py | 61 -------- .../server/store/platform_context_store.py | 53 ------- apps/adk-py/tests/e2e/conftest.py | 5 - apps/adk-py/tests/e2e/test_history.py | 108 ++++---------- .../src/adk_server/api/routes/contexts.py | 36 +---- .../src/adk_server/api/schema/contexts.py | 5 +- .../adk_server/domain/repositories/context.py | 16 +-- .../persistence/repositories/context.py | 108 +------------- .../src/adk_server/jobs/procrastinate.py | 2 - apps/adk-server/src/adk_server/jobs/queues.py | 1 - .../src/adk_server/jobs/tasks/context.py | 18 --- apps/adk-server/src/adk_server/run_workers.py | 5 - .../adk_server/service_layer/services/a2a.py | 3 +- .../service_layer/services/contexts.py | 135 +---------------- apps/adk-server/tests/e2e/agents/conftest.py | 5 - .../tests/e2e/agents/test_context_store.py | 115 --------------- .../tests/e2e/routes/test_contexts.py | 136 ------------------ apps/adk-ui/src/api/a2a/jsonrpc-client.ts | 19 +++ apps/adk-ui/src/api/a2a/list-tasks.ts | 79 ++++++++++ apps/adk-ui/src/modules/history/utils.ts | 50 ++++++- .../contexts/Messages/MessagesProvider.tsx | 64 ++------- .../contexts/Messages/messages-context.ts | 9 +- .../modules/platform-context/api/constants.ts | 4 +- .../src/modules/platform-context/api/index.ts | 13 -- .../src/modules/platform-context/api/keys.ts | 5 +- .../api/queries/useListContextHistory.ts | 60 -------- .../src/modules/platform-context/api/types.ts | 5 - .../src/modules/platform-context/api/utils.ts | 10 -- .../contexts/PlatformContextProvider.tsx | 8 +- .../contexts/platform-context.ts | 9 +- .../src/modules/runs/components/AgentRun.tsx | 16 +-- .../contexts/agent-run/AgentRunProvider.tsx | 1 - docs/development/agent-integration/canvas.mdx | 3 - .../agent-integration/multi-turn.mdx | 107 +------------- 
docs/development/agent-integration/rag.mdx | 1 - .../src/canvas_with_llm/agent.py | 3 - .../src/advanced_history/agent.py | 5 - .../basic-history/src/basic_history/agent.py | 6 - .../src/streaming_agent_history/agent.py | 10 -- .../src/conversation_rag_agent/agent.py | 1 - skills/kagenti-adk-wrapper/SKILL.md | 9 +- .../references/wrapper-entrypoint.md | 5 +- 60 files changed, 252 insertions(+), 1378 deletions(-) delete mode 100644 apps/adk-py/src/kagenti_adk/server/store/context_store.py delete mode 100644 apps/adk-py/src/kagenti_adk/server/store/memory_context_store.py delete mode 100644 apps/adk-py/src/kagenti_adk/server/store/platform_context_store.py delete mode 100644 apps/adk-server/tests/e2e/agents/test_context_store.py create mode 100644 apps/adk-ui/src/api/a2a/list-tasks.ts delete mode 100644 apps/adk-ui/src/modules/platform-context/api/queries/useListContextHistory.ts delete mode 100644 apps/adk-ui/src/modules/platform-context/api/utils.ts diff --git a/agents/canvas/src/canvas/agent.py b/agents/canvas/src/canvas/agent.py index 7c9ba8b1..a4e9308f 100644 --- a/agents/canvas/src/canvas/agent.py +++ b/agents/canvas/src/canvas/agent.py @@ -50,7 +50,6 @@ async def canvas_agent( yield "Can't run without a LLM." 
return - await context.store(message) edit_request = await canvas.parse_canvas_edit_request(message=message) user_text_content = _get_text(message) @@ -132,12 +131,6 @@ async def canvas_agent( parts=[TextPart(text=content_delta)], ) - final_artifact = AgentArtifact( - artifact_id=artifact.artifact_id, - name=artifact.name, - parts=[TextPart(text=buffer)], - ) - await context.store(final_artifact) def serve(): diff --git a/agents/chat/src/chat/agent.py b/agents/chat/src/chat/agent.py index 66a868cb..b7321f84 100644 --- a/agents/chat/src/chat/agent.py +++ b/agents/chat/src/chat/agent.py @@ -24,11 +24,10 @@ PlatformApiExtensionServer, PlatformApiExtensionSpec, ) -from kagenti_adk.a2a.types import AgentArtifact, AgentMessage +from kagenti_adk.a2a.types import AgentArtifact from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext from kagenti_adk.server.middleware.platform_auth_backend import PlatformAuthBackend -from kagenti_adk.server.store.platform_context_store import PlatformContextStore from beeai_framework.adapters.agentstack.backend.chat import AgentStackChatModel from beeai_framework.agents.requirement import RequirementAgent from beeai_framework.agents.requirement.events import ( @@ -36,7 +35,7 @@ RequirementAgentSuccessEvent, ) from beeai_framework.agents.requirement.utils._tool import FinalAnswerTool -from beeai_framework.backend import AssistantMessage, ChatModelParameters +from beeai_framework.backend import ChatModelParameters from beeai_framework.errors import FrameworkError from beeai_framework.middleware.trajectory import GlobalTrajectoryMiddleware from beeai_framework.tools import AnyTool, Tool @@ -45,7 +44,6 @@ from beeai_framework.tools.weather import OpenMeteoTool from openinference.instrumentation.beeai import BeeAIInstrumentor -from chat.helpers.citations import extract_citations from chat.helpers.trajectory import TrajectoryContent from chat.tools.files.file_creator import FileCreatorTool, FileCreatorToolOutput 
from chat.tools.files.file_reader import FileReaderTool @@ -149,8 +147,6 @@ async def chat( _p: Annotated[PlatformApiExtensionServer, PlatformApiExtensionSpec()], ): """Agent with memory and access to web search, Wikipedia, and weather.""" - await context.store(input) - # Send initial trajectory yield trajectory.trajectory_metadata(title="Starting", content="Received your request") @@ -220,7 +216,6 @@ async def chat( middlewares=[GlobalTrajectoryMiddleware(included=[Tool])], ) - final_answer: AssistantMessage | None = None new_messages = [to_framework_message(item, extracted_files) for item in history] try: @@ -244,8 +239,6 @@ async def chat( case RequirementAgentFinalAnswerEvent(delta=delta): yield delta case RequirementAgentSuccessEvent(state=state): - final_answer = state.answer - last_step = state.steps[-1] if last_step.tool and last_step.tool.name == FinalAnswerTool.name: # internal tool continue @@ -259,7 +252,6 @@ async def chat( group_id=last_step.id, ) yield metadata - await context.store(AgentMessage(metadata=metadata)) if isinstance(last_step.output, FileCreatorToolOutput): for file_info in last_step.output.result.files: @@ -267,16 +259,7 @@ async def chat( part.filename = file_info.display_filename artifact = AgentArtifact(name=file_info.display_filename, parts=[part]) yield artifact - await context.store(artifact) - - if final_answer: - citations, clean_text = extract_citations(final_answer.text) - message = AgentMessage( - text=clean_text, - metadata=(citation.citation_metadata(citations=citations) if citations else None), - ) - await context.store(message) except FrameworkError as err: raise RuntimeError(err.explain()) @@ -287,7 +270,6 @@ def serve(): host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000)), configure_telemetry=True, - context_store=PlatformContextStore(), auth_backend=PlatformAuthBackend(), ) except KeyboardInterrupt: diff --git a/agents/deepagents_content_builder/src/content_builder/agent.py 
b/agents/deepagents_content_builder/src/content_builder/agent.py index 82fac179..e713d395 100644 --- a/agents/deepagents_content_builder/src/content_builder/agent.py +++ b/agents/deepagents_content_builder/src/content_builder/agent.py @@ -91,7 +91,6 @@ async def content_builder_agent( return started_at = datetime.now(timezone.utc) - await context.store(data=message) subagents: list[SubAgent] = [] for sub_agent in AVAILABLE_SUBAGENTS: @@ -140,7 +139,6 @@ async def content_builder_agent( title=data["name"], content=json.dumps(obj=data["args"]) ) yield tool_call_metadata - await context.store(data=AgentMessage(metadata=tool_call_metadata)) tool_calls.clear() elif last_msg.tool_call_chunks: @@ -151,12 +149,10 @@ async def content_builder_agent( tool_calls[tc_id]["args"] += tc.get("args") or "" elif last_msg.text: yield AgentMessage(text=last_msg.text) - await context.store(AgentMessage(text=last_msg.text)) elif isinstance(last_msg, ToolMessage) and last_msg.name and last_msg.text: tool_message_metadata = trajectory.trajectory_metadata(title=last_msg.name, content=last_msg.text) yield tool_message_metadata - await context.store(data=AgentMessage(metadata=tool_message_metadata)) updated_files = await agent_stack_backend.alist(order_by="created_at", order="asc", created_after=started_at) for updated_file in updated_files: diff --git a/agents/rag/src/rag/agent.py b/agents/rag/src/rag/agent.py index 0199a396..36e35828 100644 --- a/agents/rag/src/rag/agent.py +++ b/agents/rag/src/rag/agent.py @@ -28,7 +28,6 @@ from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext from kagenti_adk.server.middleware.platform_auth_backend import PlatformAuthBackend -from kagenti_adk.server.store.platform_context_store import PlatformContextStore from beeai_framework.adapters.agentstack.backend.chat import AgentStackChatModel from beeai_framework.agents.requirement import RequirementAgent from beeai_framework.agents.requirement.utils._tool import FinalAnswerTool 
@@ -116,7 +115,6 @@ async def rag( _: Annotated[PlatformApiExtensionServer, PlatformApiExtensionSpec()], ): """RAG agent that retrieves and generates text based on user queries.""" - await context.store(input) llm, embedding = _get_clients(llm_ext, embedding_ext) history = [m async for m in context.load_history()] @@ -181,7 +179,6 @@ async def rag( phase="end", ).metadata(trajectory) yield vector_store_create_metadata - await context.store(AgentMessage(metadata=vector_store_create_metadata)) tools.append(cast(Tool, VectorSearchTool(vector_store_id=vector_store_id, embedding_function=embedding))) async for item in embed_all_files( @@ -300,7 +297,6 @@ async def handle_tool_success(event, meta): metadata=(citation.citation_metadata(citations=citations) if citations else None), ) yield message - await context.store(message) def _get_clients( @@ -331,7 +327,6 @@ def serve(): host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000)), configure_telemetry=True, - context_store=PlatformContextStore(), auth_backend=PlatformAuthBackend(), ) except KeyboardInterrupt: diff --git a/apps/adk-py/examples/canvas_ui_code_agent.py b/apps/adk-py/examples/canvas_ui_code_agent.py index 5af31601..91af8087 100644 --- a/apps/adk-py/examples/canvas_ui_code_agent.py +++ b/apps/adk-py/examples/canvas_ui_code_agent.py @@ -84,8 +84,6 @@ async def artifacts_agent( ): """Works with artifacts""" - await context.store(input) - canvas_edit_request = await canvas.parse_canvas_edit_request(message=input) if canvas_edit_request: @@ -126,7 +124,6 @@ async def artifacts_agent( if pre_text := response[: match.start()]: message = AgentMessage(text=pre_text) yield message - await context.store(message) await asyncio.sleep(1) @@ -153,7 +150,6 @@ async def artifacts_agent( name=artifact_name, parts=[TextPart(text=code_content)], ) - await context.store(artifact) # Send first chunk with artifact_id to establish the artifact first_artifact = AgentArtifact( @@ -171,13 +167,11 @@ async def 
artifacts_agent( parts=[TextPart(text=chunk)], ) yield chunk_artifact - await context.store(chunk_artifact) await asyncio.sleep(0.3) if post_text := response[match.end() :]: message = AgentMessage(text=post_text) yield message - await context.store(message) if __name__ == "__main__": diff --git a/apps/adk-py/examples/canvas_ui_test_agent.py b/apps/adk-py/examples/canvas_ui_test_agent.py index 87383535..7f407694 100644 --- a/apps/adk-py/examples/canvas_ui_test_agent.py +++ b/apps/adk-py/examples/canvas_ui_test_agent.py @@ -67,8 +67,6 @@ async def artifacts_agent( ): """Works with artifacts""" - await context.store(input) - canvas_edit_request = await canvas.parse_canvas_edit_request(message=input) if canvas_edit_request: @@ -106,7 +104,6 @@ async def artifacts_agent( if pre_text := response[: match.start()].strip(): message = AgentMessage(text=pre_text) yield message - await context.store(message) await asyncio.sleep(1) @@ -137,7 +134,6 @@ async def artifacts_agent( name=artifact_name, parts=[TextPart(text=recipe_content)], ) - await context.store(artifact) # Send first chunk with artifact_id to establish the artifact first_artifact = AgentArtifact( @@ -155,13 +151,11 @@ async def artifacts_agent( parts=[TextPart(text=chunk)], ) yield chunk_artifact - await context.store(chunk_artifact) await asyncio.sleep(0.3) if post_text := response[match.end() :]: message = AgentMessage(text=post_text) yield message - await context.store(message) if __name__ == "__main__": diff --git a/apps/adk-py/examples/citation_agent.py b/apps/adk-py/examples/citation_agent.py index 8929c053..7857f555 100644 --- a/apps/adk-py/examples/citation_agent.py +++ b/apps/adk-py/examples/citation_agent.py @@ -12,7 +12,6 @@ from kagenti_adk.a2a.types import AgentMessage from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore server = Server() @@ -27,9 +26,6 @@ async def example_agent( ): 
"""Agent that demonstrates citation extension usage""" - # Store the current message in the context store - await context.store(input) - # Simulate researching multiple sources research_text = """Based on recent research, artificial intelligence has made significant progress in natural language processing. Studies show that transformer models have revolutionized the field, and @@ -60,12 +56,11 @@ async def example_agent( metadata=citation.citation_metadata(citations=citations), ) yield message - await context.store(message) def run(): server.run( - host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000)), context_store=PlatformContextStore() + host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000)) ) diff --git a/apps/adk-py/examples/citation_agent_artifact.py b/apps/adk-py/examples/citation_agent_artifact.py index af089064..a92524cb 100644 --- a/apps/adk-py/examples/citation_agent_artifact.py +++ b/apps/adk-py/examples/citation_agent_artifact.py @@ -12,7 +12,6 @@ from kagenti_adk.a2a.types import AgentArtifact from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore server = Server() @@ -27,9 +26,6 @@ async def example_agent( ): """Agent that demonstrates citation extension usage""" - # Store the current message in the context store - await context.store(input) - # Simulate researching multiple sources research_text = """Based on recent research, artificial intelligence has made significant progress in natural language processing. 
Studies show that transformer models have revolutionized the field, and @@ -62,12 +58,10 @@ async def example_agent( ) yield artifact - await context.store(artifact) - def run(): server.run( - host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8002)), context_store=PlatformContextStore() + host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8002)) ) diff --git a/apps/adk-py/examples/history.py b/apps/adk-py/examples/history.py index 27c0c3ca..76562e3b 100644 --- a/apps/adk-py/examples/history.py +++ b/apps/adk-py/examples/history.py @@ -19,9 +19,6 @@ async def example_agent(input: Message, context: RunContext): """Agent that demonstrates conversation history access""" - # Store the current message in the context store - await context.store(input) - # Get the current user message current_message = get_message_text(input) print(f"Current message: {current_message}") @@ -36,9 +33,6 @@ async def example_agent(input: Message, context: RunContext): message = AgentMessage(text=f"Hello! 
I can see we have {len(history)} messages in our conversation.") yield message - # Store the message in the context store - await context.store(message) - def run(): server.run(host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000))) diff --git a/apps/adk-py/examples/history_framework.py b/apps/adk-py/examples/history_framework.py index 711fa0ab..46607fe6 100644 --- a/apps/adk-py/examples/history_framework.py +++ b/apps/adk-py/examples/history_framework.py @@ -18,7 +18,6 @@ from kagenti_adk.a2a.types import AgentMessage from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore server = Server() @@ -44,8 +43,6 @@ async def multi_turn_chat_agent( llm: Annotated[LLMServiceExtensionServer, LLMServiceExtensionSpec.single_demand()], ): """Multi-turn chat agent with conversation memory and LLM integration""" - await context.store(input) - # Load conversation history history = [message async for message in context.load_history() if isinstance(message, Message) and message.parts] @@ -81,14 +78,12 @@ async def multi_turn_chat_agent( response = AgentMessage(text=step.input["response"]) yield response - await context.store(response) def run(): server.run( host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", "8000")), - context_store=PlatformContextStore(), # Enable persistent storage ) diff --git a/apps/adk-py/examples/oauth.py b/apps/adk-py/examples/oauth.py index d3b2c9b0..511d4698 100644 --- a/apps/adk-py/examples/oauth.py +++ b/apps/adk-py/examples/oauth.py @@ -30,7 +30,6 @@ from kagenti_adk.a2a.types import AgentMessage from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore server = Server() @@ -57,8 +56,6 @@ async def oauth_agent( oauth: Annotated[OAuthExtensionServer, OAuthExtensionSpec.single_demand()], ): """Multi-turn 
chat agent with conversation memory and LLM integration""" - await context.store(input) - # pyrefly: ignore [deprecated] -- TODO: upgrade mcp_client = streamablehttp_client( url="https://mcp.stripe.com", @@ -104,14 +101,12 @@ async def oauth_agent( response = AgentMessage(text=step.input["response"]) yield response - await context.store(response) def run(): server.run( host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", "8000")), - context_store=PlatformContextStore(), # Enable persistent storage ) diff --git a/apps/adk-py/examples/trajectory_agent.py b/apps/adk-py/examples/trajectory_agent.py index 00c67226..291a59d0 100644 --- a/apps/adk-py/examples/trajectory_agent.py +++ b/apps/adk-py/examples/trajectory_agent.py @@ -13,7 +13,6 @@ from kagenti_adk.a2a.types import AgentMessage from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore server = Server() @@ -28,15 +27,11 @@ async def example_agent( ): """Agent that demonstrates conversation history access""" - # Store the current message in the context store - await context.store(input) - metadata = trajectory.trajectory_metadata( title="Initializing...", content="Initializing...", ) yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(2.5) @@ -46,7 +41,6 @@ async def example_agent( content="Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit.", ) yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(0.3) for i in range(4, 7): @@ -55,7 +49,6 @@ async def example_agent( content="Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit.Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit. 
Lorem ipsum dolor sit amet, consectetur adipiscing elit.", ) yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(0.8) await asyncio.sleep(1) @@ -65,7 +58,6 @@ async def example_agent( content="Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit.Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit.Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit.Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit.Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit.Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit. 
Lorem ipsum dolor sit amet, consectetur adipiscing elit.", ) yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(2) metadata = trajectory.trajectory_metadata( @@ -134,7 +126,6 @@ def extract_entities(text): """, ) yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(2) @@ -169,7 +160,6 @@ def extract_entities(text): }""", ) yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(1) @@ -177,25 +167,21 @@ def extract_entities(text): title="Web search", content="Querying search engines...", group_id="websearch" ) yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(4) metadata = trajectory.trajectory_metadata(content="Found 8 results.", group_id="websearch") yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(1) metadata = trajectory.trajectory_metadata(content="Found 8 results\nAnalyzed 3/8 results", group_id="websearch") yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(2) metadata = trajectory.trajectory_metadata(content="Found 8 results\nAnalyzed 8/8 results", group_id="websearch") yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(4) @@ -205,7 +191,6 @@ def extract_entities(text): group_id="websearch", ) yield metadata - await context.store(AgentMessage(metadata=metadata)) # Your agent logic here - you can now reference all messages in the conversation message = AgentMessage( @@ -213,13 +198,10 @@ def extract_entities(text): ) yield message - # Store the message in the context store - await context.store(message) - def run(): server.run( - host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000)), context_store=PlatformContextStore() + host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000)) ) diff --git a/apps/adk-py/src/kagenti_adk/platform/context.py 
b/apps/adk-py/src/kagenti_adk/platform/context.py index 51bba5e4..379af275 100644 --- a/apps/adk-py/src/kagenti_adk/platform/context.py +++ b/apps/adk-py/src/kagenti_adk/platform/context.py @@ -5,37 +5,15 @@ from __future__ import annotations import builtins -from collections.abc import AsyncIterator -from typing import Any, Literal, Self -from uuid import UUID, uuid4 - +from typing import Literal import pydantic -from a2a.types import Artifact, Message -from google.protobuf.json_format import MessageToDict, ParseDict -from pydantic import AwareDatetime, BaseModel, Field, SerializeAsAny, computed_field +from pydantic import SerializeAsAny from kagenti_adk.platform.client import PlatformClient, get_platform_client from kagenti_adk.platform.common import PaginatedResult from kagenti_adk.platform.provider import Provider from kagenti_adk.platform.types import Metadata, MetadataPatch -from kagenti_adk.util.utils import filter_dict, utc_now - - -class ContextHistoryItem(BaseModel, arbitrary_types_allowed=True): - id: UUID = Field(default_factory=uuid4) - data: Artifact | Message - created_at: AwareDatetime = Field(default_factory=utc_now) - context_id: str - - @computed_field - @property - def kind(self) -> Literal["message", "artifact"]: - return getattr(self.data, "kind", "artifact") - - @pydantic.field_validator("data", mode="before") - @classmethod - def parse_data(cls: Self, value: dict[str, Any]) -> Artifact | Message: - return ParseDict(value, Artifact() if "artifact_id" in value else Message()) +from kagenti_adk.util.utils import filter_dict class ContextToken(pydantic.BaseModel): @@ -231,69 +209,3 @@ async def generate_token( .json() ) return pydantic.TypeAdapter(ContextToken).validate_python({**token_response, "context_id": context_id}) - - async def add_history_item( - self: "Context" | str, - *, - data: Message | Artifact, - client: PlatformClient | None = None, - ) -> None: - """Add a Message or Artifact to the context history (append-only)""" - 
target_context_id = self if isinstance(self, str) else self.id - async with client or get_platform_client() as platform_client: - _ = ( - await platform_client.post( - url=f"/api/v1/contexts/{target_context_id}/history", json=MessageToDict(data) - ) - ).raise_for_status() - - async def delete_history_from_id( - self: "Context" | str, - *, - from_id: UUID | str, - client: PlatformClient | None = None, - ) -> None: - """Delete all history items from a specific item onwards (inclusive)""" - target_context_id = self if isinstance(self, str) else self.id - async with client or get_platform_client() as platform_client: - _ = ( - await platform_client.delete( - url=f"/api/v1/contexts/{target_context_id}/history", params={"from_id": str(from_id)} - ) - ).raise_for_status() - - async def list_history( - self: "Context" | str, - *, - page_token: str | None = None, - limit: int | None = None, - order: Literal["asc"] | Literal["desc"] | None = "asc", - order_by: Literal["created_at"] | Literal["updated_at"] | None = None, - client: PlatformClient | None = None, - ) -> PaginatedResult[ContextHistoryItem]: - """List all history items for this context in chronological order""" - target_context_id = self if isinstance(self, str) else self.id - async with client or get_platform_client() as platform_client: - return pydantic.TypeAdapter(PaginatedResult[ContextHistoryItem]).validate_python( - ( - await platform_client.get( - url=f"/api/v1/contexts/{target_context_id}/history", - params=filter_dict( - {"page_token": page_token, "limit": limit, "order": order, "order_by": order_by} - ), - ) - ) - .raise_for_status() - .json() - ) - - async def list_all_history( - self: "Context" | str, client: PlatformClient | None = None - ) -> AsyncIterator[ContextHistoryItem]: - result = await Context.list_history(self, client=client) - for item in result.items: - yield item - while result.has_more: - result = await Context.list_history(self, page_token=result.next_page_token, client=client) - for 
item in result.items: - yield item diff --git a/apps/adk-py/src/kagenti_adk/server/agent.py b/apps/adk-py/src/kagenti_adk/server/agent.py index 535066e2..b0c3186f 100644 --- a/apps/adk-py/src/kagenti_adk/server/agent.py +++ b/apps/adk-py/src/kagenti_adk/server/agent.py @@ -51,13 +51,12 @@ from kagenti_adk.server.context import RunContext from kagenti_adk.server.dependencies import Dependency, Depends, extract_dependencies from kagenti_adk.server.exceptions import InvalidYieldError -from kagenti_adk.server.store.context_store import ContextStore from kagenti_adk.server.utils import cancel_task, merge_messages from kagenti_adk.types import A2ASecurity, JsonPatch from kagenti_adk.util.logging import logger AgentFunction: TypeAlias = Callable[[], AsyncGenerator[RunYield, RunYieldResume]] -AgentFunctionFactory: TypeAlias = Callable[[RequestContext, ContextStore], AbstractAsyncContextManager[AgentFunction]] +AgentFunctionFactory: TypeAlias = Callable[[RequestContext, TaskStore], AbstractAsyncContextManager[AgentFunction]] OriginalFnType = TypeVar("OriginalFnType", bound=Callable[..., Any]) @@ -356,7 +355,7 @@ async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None: class AgentRun: - def __init__(self, agent: Agent, context_store: ContextStore, on_finish: Callable[[], None] | None = None) -> None: + def __init__(self, agent: Agent, task_store: TaskStore, on_finish: Callable[[], None] | None = None) -> None: self._agent: Agent = agent self._task: asyncio.Task[None] | None = None self.last_invocation: datetime = datetime.now() @@ -364,7 +363,7 @@ def __init__(self, agent: Agent, context_store: ContextStore, on_finish: Callabl self._run_context: RunContext | None = None self._request_context: RequestContext | None = None self._task_updater: TaskUpdater | None = None - self._context_store: ContextStore = context_store + self._task_store: TaskStore = task_store self._lock: asyncio.Lock = asyncio.Lock() self._on_finish: Callable[[], None] | None = on_finish 
self._working: bool = False @@ -403,14 +402,13 @@ async def start(self, request_context: RequestContext, event_queue: EventQueue): raise RuntimeError("Attempting to start a run that is already executing or done") task_id, context_id, message = request_context.task_id, request_context.context_id, request_context.message assert task_id and context_id and message - context_store = await self._context_store.create(context_id) self._run_context = RunContext( configuration=request_context.configuration, context_id=context_id, task_id=task_id, current_task=request_context.current_task, related_tasks=request_context.related_tasks, - _store=context_store, + _task_store=self._task_store, ) self._request_context = request_context self._task_updater = TaskUpdater(event_queue, task_id, context_id) @@ -601,14 +599,12 @@ def __init__( self, agent: Agent, queue_manager: QueueManager, - context_store: ContextStore, task_timeout: timedelta, task_store: TaskStore, ) -> None: self._agent: Agent = agent self._running_tasks: dict[str, AgentRun] = {} self._scheduled_cleanups: dict[str, asyncio.Task[None]] = {} - self._context_store: ContextStore = context_store self._task_timeout: timedelta = task_timeout self._task_store: TaskStore = task_store @@ -620,7 +616,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non agent_run: AgentRun | None = None try: if not context.current_task: - agent_run = AgentRun(self._agent, self._context_store, lambda: self._handle_finish(task_id)) + agent_run = AgentRun(self._agent, self._task_store, lambda: self._handle_finish(task_id)) self._running_tasks[task_id] = agent_run await self._schedule_run_cleanup(request_context=context) await agent_run.start(request_context=context, event_queue=event_queue) diff --git a/apps/adk-py/src/kagenti_adk/server/app.py b/apps/adk-py/src/kagenti_adk/server/app.py index 3c5e0bc4..a3e08ec6 100644 --- a/apps/adk-py/src/kagenti_adk/server/app.py +++ b/apps/adk-py/src/kagenti_adk/server/app.py @@ 
-25,8 +25,6 @@ from kagenti_adk.a2a.extensions import BaseExtensionServer from kagenti_adk.server.agent import Agent, Executor from kagenti_adk.server.constants import DEFAULT_IMPLICIT_EXTENSIONS -from kagenti_adk.server.store.context_store import ContextStore -from kagenti_adk.server.store.memory_context_store import InMemoryContextStore from kagenti_adk.types import SdkAuthenticationBackend @@ -34,7 +32,6 @@ def create_app( agent: Agent, url: str, task_store: TaskStore | None = None, - context_store: ContextStore | None = None, implicit_extensions: dict[str, BaseExtensionServer] = DEFAULT_IMPLICIT_EXTENSIONS, required_extensions: set[str] | None = None, auth_backend: SdkAuthenticationBackend | None = None, @@ -49,12 +46,10 @@ def create_app( ) -> FastAPI: queue_manager = queue_manager or InMemoryQueueManager() task_store = task_store or InMemoryTaskStore() - context_store = context_store or InMemoryContextStore() http_handler = DefaultRequestHandler( agent_executor=Executor( agent, queue_manager, - context_store=context_store, task_timeout=task_timeout, task_store=task_store, ), @@ -73,7 +68,7 @@ def create_app( AgentInterface(url=url + "/jsonrpc/", protocol_binding="JSONRPC", protocol_version=protocol_version), ], implicit_extensions=implicit_extensions, - required_extensions=(required_extensions or set()) | context_store.required_extensions, + required_extensions=required_extensions or set(), ) jsonrpc_app = A2AFastAPIApplication(agent_card=agent.card, http_handler=http_handler).build( diff --git a/apps/adk-py/src/kagenti_adk/server/context.py b/apps/adk-py/src/kagenti_adk/server/context.py index 01e85eb2..bc0f7567 100644 --- a/apps/adk-py/src/kagenti_adk/server/context.py +++ b/apps/adk-py/src/kagenti_adk/server/context.py @@ -4,21 +4,17 @@ from __future__ import annotations from collections.abc import AsyncGenerator -from typing import Literal, overload -from uuid import UUID import janus +from a2a.server.tasks import TaskStore from a2a.types import ( 
Artifact, Message, Task, ) -from asgiref.sync import async_to_sync from pydantic import BaseModel, PrivateAttr from kagenti_adk.a2a.types import RunYield, RunYieldResume -from kagenti_adk.platform.context import ContextHistoryItem -from kagenti_adk.server.store.context_store import ContextStoreInstance class RunContextSettings(BaseModel): @@ -32,49 +28,25 @@ class RunContext(BaseModel, arbitrary_types_allowed=True): related_tasks: list[Task] | None = None strict: bool = False # TODO: explain strict mode - what yields will stop message etc. Use in match/case - _store: ContextStoreInstance + _task_store: TaskStore _yield_queue: janus.Queue[RunYield] = PrivateAttr(default_factory=janus.Queue) _yield_resume_queue: janus.Queue[RunYieldResume | Exception] = PrivateAttr(default_factory=janus.Queue) - def __init__(self, _store: ContextStoreInstance, **data): + def __init__(self, _task_store: TaskStore, **data): super().__init__(**data) - self._store = _store - - def _prepare_store_data(self, data: Message | Artifact) -> Message | Artifact: - if not self._store: - raise RuntimeError("Context store is not initialized") - if isinstance(data, Message): - msg = Message() - msg.CopyFrom(data) - msg.context_id = self.context_id - msg.task_id = self.task_id - return msg - return data - - async def store(self, data: Message | Artifact): - await self._store.store(self._prepare_store_data(data)) - - def store_sync(self, data: Message | Artifact): - async_to_sync(self._store.store)(self._prepare_store_data(data)) - - @overload - def load_history(self, load_history_items: Literal[False] = False) -> AsyncGenerator[Message | Artifact, None]: ... - - @overload - def load_history(self, load_history_items: Literal[True]) -> AsyncGenerator[ContextHistoryItem, None]: ... 
- - async def load_history( - self, load_history_items: bool = False - ) -> AsyncGenerator[ContextHistoryItem | Message | Artifact]: - if not self._store: - raise RuntimeError("Context store is not initialized") - async for item in self._store.load_history(load_history_items=load_history_items): - yield item - - async def delete_history_from_id(self, from_id: UUID) -> None: - if not self._store: - raise RuntimeError("Context store is not initialized") - await self._store.delete_history_from_id(from_id) + self._task_store = _task_store + + async def load_history(self) -> AsyncGenerator[Message | Artifact, None]: + """Load conversation history from the A2A TaskStore. + + Yields messages and artifacts from the current task's history. + """ + task = await self._task_store.get(self.task_id) + if task: + for msg in task.history: + yield msg + for artifact in task.artifacts: + yield artifact def yield_sync(self, value: RunYield) -> RunYieldResume: self._yield_queue.sync_q.put(value) diff --git a/apps/adk-py/src/kagenti_adk/server/server.py b/apps/adk-py/src/kagenti_adk/server/server.py index a3ea4029..9e226d8a 100644 --- a/apps/adk-py/src/kagenti_adk/server/server.py +++ b/apps/adk-py/src/kagenti_adk/server/server.py @@ -37,7 +37,6 @@ from kagenti_adk.server.agent import Agent from kagenti_adk.server.agent import agent as agent_decorator from kagenti_adk.server.constants import DEFAULT_IMPLICIT_EXTENSIONS -from kagenti_adk.server.store.context_store import ContextStore from kagenti_adk.server.telemetry import configure_telemetry as configure_telemetry_func from kagenti_adk.server.utils import cancel_task from kagenti_adk.types import SdkAuthenticationBackend @@ -49,7 +48,6 @@ class Server: def __init__(self) -> None: self._agent: Agent | None = None self.server: uvicorn.Server | None = None - self._context_store: ContextStore | None = None self._self_registration_client: PlatformClient | None = None self._self_registration_id: str | None = None self._provider_id: str | 
None = None @@ -73,7 +71,6 @@ async def serve( self_registration: bool = True, self_registration_id: str | None = None, task_store: TaskStore | None = None, - context_store: ContextStore | None = None, queue_manager: QueueManager | None = None, task_timeout: timedelta = timedelta(minutes=10), push_config_store: PushNotificationConfigStore | None = None, @@ -195,7 +192,6 @@ async def _lifespan_fn(app: FastAPI) -> AsyncGenerator[None, None]: lifespan=_lifespan_fn, implicit_extensions=implicit_extensions, task_store=task_store, - context_store=context_store, queue_manager=queue_manager, push_config_store=push_config_store, push_sender=push_sender, diff --git a/apps/adk-py/src/kagenti_adk/server/store/context_store.py b/apps/adk-py/src/kagenti_adk/server/store/context_store.py deleted file mode 100644 index 88e76c58..00000000 --- a/apps/adk-py/src/kagenti_adk/server/store/context_store.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2026 © IBM Corp. -# SPDX-License-Identifier: Apache-2.0 - - -from __future__ import annotations - -import abc -from collections.abc import AsyncIterator -from typing import Protocol -from uuid import UUID - -from a2a.types import Artifact, Message - -from kagenti_adk.platform.context import ContextHistoryItem - -__all__ = [ - "ContextStore", - "ContextStoreInstance", -] - - -class ContextStoreInstance(Protocol): - def load_history( - self, load_history_items: bool = False - ) -> AsyncIterator[ContextHistoryItem | Message | Artifact]: ... - async def store(self, data: Message | Artifact) -> None: ... - async def delete_history_from_id(self, from_id: UUID) -> None: ... - - -class ContextStore(abc.ABC): - @property - def required_extensions(self) -> set[str]: - return set() - - @abc.abstractmethod - async def create(self, context_id: str) -> ContextStoreInstance: ... 
diff --git a/apps/adk-py/src/kagenti_adk/server/store/memory_context_store.py b/apps/adk-py/src/kagenti_adk/server/store/memory_context_store.py deleted file mode 100644 index 29342639..00000000 --- a/apps/adk-py/src/kagenti_adk/server/store/memory_context_store.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2026 © IBM Corp. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from collections.abc import AsyncIterator -from datetime import timedelta -from uuid import UUID - -from a2a.types import Artifact, Message -from cachetools import TTLCache -from google.protobuf.json_format import MessageToDict - -from kagenti_adk.platform.context import ContextHistoryItem -from kagenti_adk.server.store.context_store import ContextStore, ContextStoreInstance - - -class MemoryContextStoreInstance(ContextStoreInstance): - def __init__(self, context_id: str): - self.context_id = context_id - self._history: list[ContextHistoryItem] = [] - - async def load_history( - self, load_history_items: bool = False - ) -> AsyncIterator[ContextHistoryItem | Message | Artifact]: - for item in self._history.copy(): - if load_history_items: - yield item - else: - yield item.data - - async def store(self, data: Message | Artifact) -> None: - self._history.append(ContextHistoryItem(data=MessageToDict(data), context_id=self.context_id)) - - async def delete_history_from_id(self, from_id: UUID) -> None: - # Does not allow to delete from an artifact onwards - index = next( - (i for i, item in enumerate(self._history) if item.id == from_id), - None, - ) - if index is not None: - self._history = self._history[:index] - - -class InMemoryContextStore(ContextStore): - def __init__(self, max_contexts: int = 1000, context_ttl: timedelta = timedelta(hours=1)): - """ - Initialize in-memory context store with TTL cache. 
- - Args: - max_contexts: Maximum number of contexts to keep in memory - ttl_seconds: Time-to-live for context instances in seconds (default: 1 hour) - """ - self._instances: TTLCache[str, MemoryContextStoreInstance] = TTLCache( - maxsize=max_contexts, ttl=context_ttl.total_seconds() - ) - - async def create(self, context_id: str) -> ContextStoreInstance: - if context_id not in self._instances: - self._instances[context_id] = MemoryContextStoreInstance(context_id) - return self._instances[context_id] diff --git a/apps/adk-py/src/kagenti_adk/server/store/platform_context_store.py b/apps/adk-py/src/kagenti_adk/server/store/platform_context_store.py deleted file mode 100644 index d1b20d85..00000000 --- a/apps/adk-py/src/kagenti_adk/server/store/platform_context_store.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2026 © IBM Corp. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from uuid import UUID - -from a2a.types import Artifact, Message - -from kagenti_adk.a2a.extensions.services.platform import PlatformApiExtensionServer, PlatformApiExtensionSpec -from kagenti_adk.platform.context import Context, ContextHistoryItem -from kagenti_adk.server.store.context_store import ContextStore, ContextStoreInstance - - -class PlatformContextStore(ContextStore): - @property - def required_extensions(self) -> set[str]: - return {PlatformApiExtensionSpec.URI} - - async def create(self, context_id: str) -> ContextStoreInstance: - return PlatformContextStoreInstance(context_id=context_id) - - -class PlatformContextStoreInstance(ContextStoreInstance): - def __init__(self, context_id: str): - self._context_id = context_id - - @asynccontextmanager - async def client(self): - if not (ext := PlatformApiExtensionServer.current()): - raise RuntimeError("PlatformApiExtensionServer is not initialized") - async with ext.use_client(): - yield - - async def load_history( - 
self, load_history_items: bool = False - ) -> AsyncIterator[ContextHistoryItem | Message | Artifact]: - async with self.client(): - async for history_item in Context.list_all_history(self._context_id): - if load_history_items: - yield history_item - else: - yield history_item.data - - async def store(self, data: Message | Artifact) -> None: - async with self.client(): - await Context.add_history_item(self._context_id, data=data) - - async def delete_history_from_id(self, from_id: UUID) -> None: - async with self.client(): - await Context.delete_history_from_id(self._context_id, from_id=from_id) diff --git a/apps/adk-py/tests/e2e/conftest.py b/apps/adk-py/tests/e2e/conftest.py index 9b0d998d..abbefc2e 100644 --- a/apps/adk-py/tests/e2e/conftest.py +++ b/apps/adk-py/tests/e2e/conftest.py @@ -30,7 +30,6 @@ from kagenti_adk.a2a.types import AgentArtifact, AgentMessage, ArtifactChunk, InputRequired, RunYield, RunYieldResume from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.context_store import ContextStore pytestmark = pytest.mark.e2e @@ -52,7 +51,6 @@ def make_extension_context(extensions: list[str] | None = None) -> ClientCallCon async def run_server( server: Server, port: int, - context_store: ContextStore | None = None, task_timeout: timedelta | None = None, ) -> AsyncGenerator[tuple[Server, Client]]: async with asyncio.TaskGroup() as tg: @@ -61,7 +59,6 @@ async def run_server( server.run, self_registration=False, port=port, - context_store=context_store, task_timeout=task_timeout or timedelta(minutes=5), ) ) @@ -92,7 +89,6 @@ def create_server_with_agent(): @asynccontextmanager async def _create_server( agent_fn, - context_store: ContextStore | None = None, task_timeout: timedelta | None = None, ) -> AsyncIterator[tuple[Server, Client]]: server = Server() @@ -100,7 +96,6 @@ async def _create_server( async with run_server( server, get_free_port(), - context_store=context_store, 
task_timeout=task_timeout, ) as (server, client): yield server, client diff --git a/apps/adk-py/tests/e2e/test_history.py b/apps/adk-py/tests/e2e/test_history.py index 786e8b4a..e67254b9 100644 --- a/apps/adk-py/tests/e2e/test_history.py +++ b/apps/adk-py/tests/e2e/test_history.py @@ -8,16 +8,15 @@ import pytest from a2a.client import Client, ClientEvent, create_text_message_object from a2a.types import ( + Artifact, Message, - Role, SendMessageRequest, Task, ) -from kagenti_adk.a2a.types import RunYield +from kagenti_adk.a2a.types import AgentMessage, RunYield from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.memory_context_store import InMemoryContextStore pytestmark = pytest.mark.e2e @@ -44,86 +43,31 @@ async def send_message_get_response( @pytest.fixture -async def history_agent(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: - """Agent that tests context.store.load_history() functionality.""" - context_store = InMemoryContextStore() - - async def history_agent(input: Message, context: RunContext) -> AsyncGenerator[RunYield, None]: - await context.store(input) - async for message in context.load_history(): - message.role = Role.ROLE_AGENT - yield message - await context.store(message) - - async with create_server_with_agent(history_agent, context_store=context_store) as (server, client): +async def history_reader_agent(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: + """Agent that reads history from the task store via RunContext.load_history().""" + + async def history_reader(input: Message, context: RunContext) -> AsyncGenerator[RunYield, None]: + # Load history from the task store (will contain messages from previous interactions in same task) + history_items: list[str] = [] + async for item in context.load_history(): + if isinstance(item, Message) and item.parts: + history_items.append(item.parts[0].text) + elif isinstance(item, Artifact) and 
item.parts: + history_items.append(f"artifact:{item.parts[0].text}") + + # Echo back what we found in history plus the current input + if history_items: + yield AgentMessage(text=f"history={','.join(history_items)}") + yield AgentMessage(text=f"input={input.parts[0].text}") + + async with create_server_with_agent(history_reader) as (server, client): yield server, client -@pytest.fixture -async def history_deleting_agent(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: - """Agent that tests context.store.load_history() functionality.""" - context_store = InMemoryContextStore() - - async def history_agent(input: Message, context: RunContext) -> AsyncGenerator[RunYield, None]: - await context.store(input) - n_messages = 0 - async for message in context.load_history(load_history_items=True): - n_messages += 1 - if n_messages == 1: - delete_id = message.id - if n_messages > 3: - # pyrefly: ignore [unbound-name] - await context.delete_history_from_id(delete_id) - break - - async for message in context.load_history(): - message.role = Role.ROLE_AGENT - yield message - - async with create_server_with_agent(history_agent, context_store=context_store) as (server, client): - yield server, client - - -async def test_agent_history(history_agent): - """Test that history starts empty.""" - _, client = history_agent - - agent_messages, context_id = await send_message_get_response(client, "first message") - assert agent_messages == ["first message"] - - agent_messages, context_id = await send_message_get_response(client, "second message", context_id=context_id) - assert agent_messages == ["first message", "first message", "second message"] - - agent_messages, context_id = await send_message_get_response(client, "third message", context_id=context_id) - assert agent_messages == [ - # first run - "first message", - # second run - "first message", - "second message", - # third run - "first message", - "first message", - "second message", - "third message", - ] - - 
-async def test_agent_deleting_history(history_deleting_agent): - """Test that history starts empty.""" - _, client = history_deleting_agent - - agent_messages, context_id = await send_message_get_response(client, "first message") - assert agent_messages == ["first message"] - - agent_messages, context_id = await send_message_get_response(client, "second message", context_id=context_id) - assert agent_messages == ["first message", "second message"] - - agent_messages, context_id = await send_message_get_response(client, "third message", context_id=context_id) - assert agent_messages == ["first message", "second message", "third message"] - - agent_messages, context_id = await send_message_get_response(client, "delete message", context_id=context_id) - assert agent_messages == [] +async def test_load_history_from_task_store(history_reader_agent): + """Test that RunContext.load_history() reads from the A2A task store.""" + _, client = history_reader_agent - agent_messages, context_id = await send_message_get_response(client, "first message") - assert agent_messages == ["first message"] + # First message — no history yet + agent_messages, context_id = await send_message_get_response(client, "hello") + assert any("input=hello" in msg for msg in agent_messages) diff --git a/apps/adk-server/src/adk_server/api/routes/contexts.py b/apps/adk-server/src/adk_server/api/routes/contexts.py index 5a501a8d..4e2d95eb 100644 --- a/apps/adk-server/src/adk_server/api/routes/contexts.py +++ b/apps/adk-server/src/adk_server/api/routes/contexts.py @@ -14,13 +14,11 @@ from adk_server.api.dependencies import ( ConfigurationDependency, ContextServiceDependency, - RequiresContextPermissionsPath, RequiresPermissions, ) -from adk_server.api.schema.common import EntityModel, PaginationQuery +from adk_server.api.schema.common import EntityModel from adk_server.api.schema.contexts import ( ContextCreateRequest, - ContextHistoryItemCreateRequest, ContextListQuery, ContextPatchMetadataRequest, 
ContextTokenCreateRequest, @@ -28,7 +26,7 @@ ContextUpdateRequest, ) from adk_server.domain.models.common import PaginatedResult -from adk_server.domain.models.context import Context, ContextHistoryItem +from adk_server.domain.models.context import Context from adk_server.domain.models.permissions import AuthorizedUser, Permissions logger = logging.getLogger(__name__) @@ -136,33 +134,3 @@ async def generate_context_token( configuration=configuration, ) return ContextTokenResponse(token=token, expires_at=expires_at) - - -@router.post("/{context_id}/history", status_code=status.HTTP_201_CREATED) -async def add_context_history_item( - context_id: UUID, - history_item_data: ContextHistoryItemCreateRequest, - context_service: ContextServiceDependency, - user: Annotated[AuthorizedUser, Depends(RequiresContextPermissionsPath(context_data={"write"}))], -) -> None: - await context_service.add_history_item(context_id=context_id, data=history_item_data.root, user=user.user) - - -@router.get("/{context_id}/history") -async def list_context_history( - context_id: UUID, - context_service: ContextServiceDependency, - user: Annotated[AuthorizedUser, Depends(RequiresContextPermissionsPath(context_data={"read"}))], - pagination: Annotated[PaginationQuery, Query()], -) -> PaginatedResult[ContextHistoryItem]: - return await context_service.list_history(context_id=context_id, user=user.user, pagination=pagination) - - -@router.delete("/{context_id}/history", status_code=status.HTTP_204_NO_CONTENT) -async def delete_context_history_from_id( - context_id: UUID, - from_id: Annotated[UUID, Query()], - context_service: ContextServiceDependency, - user: Annotated[AuthorizedUser, Depends(RequiresContextPermissionsPath(context_data={"read", "write"}))], -) -> None: - await context_service.delete_history_from_id(context_id=context_id, from_id=from_id, user=user.user) diff --git a/apps/adk-server/src/adk_server/api/schema/contexts.py b/apps/adk-server/src/adk_server/api/schema/contexts.py index 
27c4c483..9a3a599d 100644 --- a/apps/adk-server/src/adk_server/api/schema/contexts.py +++ b/apps/adk-server/src/adk_server/api/schema/contexts.py @@ -5,11 +5,10 @@ from typing import Literal from uuid import UUID -from pydantic import AwareDatetime, BaseModel, Field, RootModel, field_validator +from pydantic import AwareDatetime, BaseModel, Field, field_validator from adk_server.api.schema.common import PaginationQuery from adk_server.domain.models.common import Metadata, MetadataPatch -from adk_server.domain.models.context import ContextHistoryItemData class ContextCreateRequest(BaseModel): @@ -88,5 +87,3 @@ class ContextTokenResponse(BaseModel): expires_at: AwareDatetime | None -class ContextHistoryItemCreateRequest(RootModel[ContextHistoryItemData]): - root: ContextHistoryItemData diff --git a/apps/adk-server/src/adk_server/domain/repositories/context.py b/apps/adk-server/src/adk_server/domain/repositories/context.py index df818ef6..738849f9 100644 --- a/apps/adk-server/src/adk_server/domain/repositories/context.py +++ b/apps/adk-server/src/adk_server/domain/repositories/context.py @@ -9,7 +9,7 @@ from uuid import UUID from adk_server.domain.models.common import PaginatedResult -from adk_server.domain.models.context import Context, ContextHistoryItem, TitleGenerationState +from adk_server.domain.models.context import Context class IContextRepository(Protocol): @@ -33,17 +33,3 @@ async def get(self, *, context_id: UUID, user_id: UUID | None = None) -> Context async def update(self, *, context: Context) -> None: ... async def delete(self, *, context_id: UUID, user_id: UUID | None = None) -> int: ... async def update_last_active(self, *, context_id: UUID) -> None: ... - async def update_title( - self, *, context_id: UUID, title: str | None = None, generation_state: TitleGenerationState - ) -> None: ... - async def add_history_item(self, *, context_id: UUID, history_item: ContextHistoryItem) -> None: ... 
- async def list_history( - self, - *, - context_id: UUID, - page_token: UUID | None = None, - limit: int = 20, - order_by: str = "created_at", - order="desc", - ) -> PaginatedResult[ContextHistoryItem]: ... - async def delete_history_from_id(self, *, context_id: UUID, from_id: UUID) -> int: ... diff --git a/apps/adk-server/src/adk_server/infrastructure/persistence/repositories/context.py b/apps/adk-server/src/adk_server/infrastructure/persistence/repositories/context.py index 119ffbf9..a975b517 100644 --- a/apps/adk-server/src/adk_server/infrastructure/persistence/repositories/context.py +++ b/apps/adk-server/src/adk_server/infrastructure/persistence/repositories/context.py @@ -5,16 +5,14 @@ from collections.abc import AsyncIterator from datetime import datetime -from uuid import UUID, uuid4 +from uuid import UUID from kink import inject -from pydantic import TypeAdapter from sqlalchemy import ( JSON, Column, DateTime, ForeignKey, - Index, Row, Table, delete, @@ -24,8 +22,8 @@ from sqlalchemy import UUID as SQL_UUID from sqlalchemy.ext.asyncio import AsyncConnection -from adk_server.domain.models.common import Metadata, PaginatedResult -from adk_server.domain.models.context import Context, ContextHistoryItem, TitleGenerationState +from adk_server.domain.models.common import PaginatedResult +from adk_server.domain.models.context import Context from adk_server.domain.repositories.context import IContextRepository from adk_server.exceptions import EntityNotFoundError from adk_server.infrastructure.persistence.repositories.db_metadata import metadata @@ -44,16 +42,6 @@ Column("metadata", JSON, nullable=True), ) -context_history_table = Table( - "context_history", - metadata, - Column("id", SQL_UUID, primary_key=True), - Column("context_id", ForeignKey("contexts.id", ondelete="CASCADE"), nullable=False), - Column("created_at", DateTime(timezone=True), nullable=False), - Column("data", JSON, nullable=False), - Index("idx_context_history_context_id", "context_id"), -) - 
@inject class SqlAlchemyContextRepository(IContextRepository): @@ -92,12 +80,8 @@ async def list_paginated( query = query.where(contexts_table.c.provider_id == provider_id) if last_active_before: query = query.where(contexts_table.c.last_active_at < last_active_before) - if not include_empty: - # Use EXISTS subquery to find contexts that have at least one history record - subquery = select(context_history_table.c.context_id).where( - context_history_table.c.context_id == contexts_table.c.id - ) - query = query.where(subquery.exists()) + # NOTE: include_empty is accepted but no longer filtered — context_history table has been removed. + # All contexts are now returned regardless of whether they have task history. result = await cursor_paginate( connection=self._connection, @@ -163,88 +147,6 @@ async def update_last_active(self, *, context_id: UUID) -> None: query = update(contexts_table).where(contexts_table.c.id == context_id).values(last_active_at=utc_now()) await self._connection.execute(query) - async def update_title( - self, *, context_id: UUID, title: str | None = None, generation_state: TitleGenerationState - ) -> None: - # validate length before saving to database - if title: - _ = TypeAdapter(Metadata).validate_python({"title": title}) - context = await self.get(context_id=context_id) - query = ( - contexts_table.update() - .where(contexts_table.c.id == context_id) - .values( - metadata=(context.metadata or {}) - | ({"title": title} if title else {}) - | {"title_generation_state": generation_state} - ) - ) - await self._connection.execute(query) - - async def add_history_item(self, *, context_id: UUID, history_item: ContextHistoryItem) -> None: - query = context_history_table.insert().values( - id=uuid4(), - context_id=history_item.context_id, - created_at=history_item.created_at, - data=history_item.data, - ) - await self._connection.execute(query) - - async def list_history( - self, - *, - context_id: UUID, - page_token: UUID | None = None, - limit: int 
= 20, - order_by: str = "created_at", - order="desc", - ) -> PaginatedResult[ContextHistoryItem]: - query = context_history_table.select().where(context_history_table.c.context_id == context_id) - result = await cursor_paginate( - connection=self._connection, - query=query, - after_cursor=page_token, - id_column=context_history_table.c.id, - order_column=getattr(context_history_table.c, order_by), - order=order, - limit=limit, - ) - return PaginatedResult( - items=[self._row_to_context_history_item(item) for item in result.items], - total_count=result.total_count, - has_more=result.has_more, - ) - - async def delete_history_from_id(self, *, context_id: UUID, from_id: UUID) -> int: - """Delete all history items from a specific item onwards (inclusive) in given context""" - # First, get the created_at timestamp of the item to delete from - query_item = select(context_history_table.c.created_at).where( - context_history_table.c.context_id == context_id, - context_history_table.c.id == from_id, - ) - result = await self._connection.execute(query_item) - row = result.first() - if not row: - raise EntityNotFoundError("context_history_item", from_id) - - created_at = row[0] - - # Delete all history items from the specified item onwards (created_at >= the target item's created_at) - query = delete(context_history_table).where( - context_history_table.c.context_id == context_id, - context_history_table.c.created_at >= created_at, - ) - result = await self._connection.execute(query) - return result.rowcount - - def _row_to_context_history_item(self, row: Row) -> ContextHistoryItem: - return ContextHistoryItem( - id=row.id, - data=row.data, - context_id=row.context_id, - created_at=row.created_at, - ) - def _row_to_context(self, row: Row) -> Context: return Context( id=row.id, diff --git a/apps/adk-server/src/adk_server/jobs/procrastinate.py b/apps/adk-server/src/adk_server/jobs/procrastinate.py index 5a171c05..94e30efe 100644 --- 
a/apps/adk-server/src/adk_server/jobs/procrastinate.py +++ b/apps/adk-server/src/adk_server/jobs/procrastinate.py @@ -13,7 +13,6 @@ from adk_server.jobs.crons.connector import blueprint as connector_crons from adk_server.jobs.crons.model_provider import blueprint as model_provider_crons from adk_server.jobs.crons.provider import blueprint as provider_crons -from adk_server.jobs.tasks.context import blueprint as context_tasks from adk_server.jobs.tasks.file import blueprint as file_tasks logger = logging.getLogger(__name__) @@ -54,7 +53,6 @@ def exit_app_on_db_error(*_args, **_kwargs): worker_defaults=WorkerOptions(install_signal_handlers=False), ) app.add_tasks_from(blueprint=file_tasks, namespace="text_extraction") - app.add_tasks_from(blueprint=context_tasks, namespace="context_tasks") app.add_tasks_from(blueprint=provider_crons, namespace="cron_provider") app.add_tasks_from(blueprint=model_provider_crons, namespace="cron_model_provider") app.add_tasks_from(blueprint=cleanup_crons, namespace="cron_cleanup") diff --git a/apps/adk-server/src/adk_server/jobs/queues.py b/apps/adk-server/src/adk_server/jobs/queues.py index cabf8600..2965a8e5 100644 --- a/apps/adk-server/src/adk_server/jobs/queues.py +++ b/apps/adk-server/src/adk_server/jobs/queues.py @@ -13,7 +13,6 @@ class Queues(StrEnum): CRON_MODEL_PROVIDER = "cron:model_provider" CRON_CONNECTOR = "cron:connector" # tasks - GENERATE_CONVERSATION_TITLE = "generate_conversation_title" TEXT_EXTRACTION = "text_extraction" TOOLKIT_DELETION = "toolkit_deletion" diff --git a/apps/adk-server/src/adk_server/jobs/tasks/context.py b/apps/adk-server/src/adk_server/jobs/tasks/context.py index 9d8fb7cd..7abe8ca6 100644 --- a/apps/adk-server/src/adk_server/jobs/tasks/context.py +++ b/apps/adk-server/src/adk_server/jobs/tasks/context.py @@ -1,20 +1,2 @@ # Copyright 2026 © IBM Corp. 
# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from uuid import UUID - -from kink import inject -from procrastinate import Blueprint - -from adk_server.jobs.queues import Queues -from adk_server.service_layer.services.contexts import ContextService - -blueprint = Blueprint() - - -@blueprint.task(queue=str(Queues.GENERATE_CONVERSATION_TITLE)) -@inject -async def generate_conversation_title(context_id: str, context_service: ContextService): - await context_service.generate_conversation_title(context_id=UUID(context_id)) diff --git a/apps/adk-server/src/adk_server/run_workers.py b/apps/adk-server/src/adk_server/run_workers.py index ee8b08f0..d1b3d8d7 100644 --- a/apps/adk-server/src/adk_server/run_workers.py +++ b/apps/adk-server/src/adk_server/run_workers.py @@ -29,11 +29,6 @@ async def run_workers(app: procrastinate.App): ], concurrency=10, ), - WorkerOptions( - name="generate_conversation_title_worker", - queues=[str(Queues.GENERATE_CONVERSATION_TITLE)], - concurrency=10, - ), WorkerOptions(name="text_extraction_worker", queues=[str(Queues.TEXT_EXTRACTION)], concurrency=5), ] diff --git a/apps/adk-server/src/adk_server/service_layer/services/a2a.py b/apps/adk-server/src/adk_server/service_layer/services/a2a.py index 98f2f3c2..dbfef5ce 100644 --- a/apps/adk-server/src/adk_server/service_layer/services/a2a.py +++ b/apps/adk-server/src/adk_server/service_layer/services/a2a.py @@ -380,7 +380,8 @@ async def on_list_tasks( params: ListTasksRequest, context: ServerCallContext, ) -> ListTasksResponse: - raise NotImplementedError("This is not supported by the client transport yet") + async with self._client_transport(context) as transport: + return await transport.list_tasks(params, context=self._forward_context(context)) @inject diff --git a/apps/adk-server/src/adk_server/service_layer/services/contexts.py b/apps/adk-server/src/adk_server/service_layer/services/contexts.py index 25c5a8ee..3c16eb8a 100644 --- 
a/apps/adk-server/src/adk_server/service_layer/services/contexts.py +++ b/apps/adk-server/src/adk_server/service_layer/services/contexts.py @@ -4,10 +4,8 @@ from __future__ import annotations import logging -from collections.abc import Sequence from contextlib import suppress from datetime import timedelta -from typing import Any from uuid import UUID from fastapi import status @@ -15,19 +13,12 @@ from pydantic import TypeAdapter from adk_server.api.schema.common import PaginationQuery -from adk_server.api.schema.openai import ChatCompletionRequest from adk_server.configuration import Configuration from adk_server.domain.models.common import Metadata, MetadataPatch, PaginatedResult -from adk_server.domain.models.context import ( - Context, - ContextHistoryItem, - ContextHistoryItemData, - TitleGenerationState, -) +from adk_server.domain.models.context import Context from adk_server.domain.models.user import User from adk_server.domain.repositories.file import IObjectStorageRepository from adk_server.exceptions import EntityNotFoundError, PlatformError -from adk_server.service_layer.services.model_providers import ModelProviderService from adk_server.service_layer.unit_of_work import IUnitOfWorkFactory from adk_server.utils.utils import filter_dict, utc_now @@ -41,13 +32,11 @@ def __init__( uow: IUnitOfWorkFactory, configuration: Configuration, object_storage: IObjectStorageRepository, - model_provider_service: ModelProviderService, ): self._uow = uow self._object_storage = object_storage self._configuration = configuration self._expire_resources_after = timedelta(days=configuration.context.resources_expire_after_days) - self._model_provider_service = model_provider_service async def create(self, *, user: User, metadata: Metadata, provider_id: UUID | None = None) -> Context: context = Context(created_by=user.id, metadata=metadata, provider_id=provider_id) @@ -160,125 +149,3 @@ async def update_last_active(self, *, context_id: UUID) -> None: await 
uow.contexts.update_last_active(context_id=context_id) await uow.commit() - def _extract_content_for_title(self, msg: dict[str, Any]) -> tuple[str, str | None, Sequence[dict[str, Any]]]: - title_hint: str | None = None - text_parts: list[str] = [] - files: list[dict[str, Any]] = [] - for part in msg.get("parts", []): - if "text" in part: - text_parts.append(part["text"]) - elif "data" in part: - data = part["data"] - if isinstance(data, dict): - hint = data.get("title_hint") - if isinstance(hint, str) and hint and not title_hint: - title_hint = hint - elif "file" in part: - files.append(part["file"]) - - return "".join(text_parts), title_hint, files - - async def add_history_item(self, *, context_id: UUID, data: ContextHistoryItemData, user: User) -> None: - async with self._uow() as uow: - context = await uow.contexts.get(context_id=context_id, user_id=user.id) - await uow.contexts.add_history_item( - context_id=context_id, - history_item=ContextHistoryItem(context_id=context_id, data=data), - ) - - if data.get("role") == "ROLE_USER" and not (context.metadata or {}).get("title"): - from adk_server.jobs.tasks.context import generate_conversation_title as task - - # Use simple text extraction for the initial title placeholder - title = self._extract_content_for_title(data)[0] or "Untitled" - title = f"{title[:100]}..." 
if len(title) > 100 else title - - should_generate_title = self._configuration.generate_conversation_title.enabled - state = TitleGenerationState.PENDING if should_generate_title else TitleGenerationState.COMPLETED - await uow.contexts.update_title(context_id=context_id, title=title, generation_state=state) - - if should_generate_title: - await task.configure(queueing_lock=str(context_id)).defer_async(context_id=str(context_id)) - - await uow.commit() - - async def generate_conversation_title(self, *, context_id: UUID): - from jinja2 import Template - - async with self._uow() as uow: - msg = await uow.contexts.list_history(context_id=context_id, limit=1, order="desc", order_by="created_at") - system_config = await uow.configuration.get_system_configuration() - - model = self._configuration.generate_conversation_title.model - if model == "default": - if not system_config.default_llm_model: - logger.warning(f"Cannot generate title for context {context_id}: default LLM model not set.") - return - model = system_config.default_llm_model - - if not msg.items: - logger.warning(f"Cannot generate title for context {context_id}: no history found.") - return - - raw_message = msg.items[0].data - text, title_hint, files = self._extract_content_for_title(raw_message) - if not text and not title_hint and not files: - logger.warning(f"Cannot generate title for context {context_id}: first message has no content.") - return - - try: - # Render the system prompt using Jinja2 - template = Template(self._configuration.generate_conversation_title.prompt) - prompt = template.render( - text=text, - titleHint=title_hint, - files=[{"name": f.get("name"), "mime_type": f.get("mime_type")} for f in files], - rawMessage=raw_message, - ) - resp = await self._model_provider_service.create_chat_completion( - request=ChatCompletionRequest( - model=model, - stream=False, - max_completion_tokens=100, - messages=[{"role": "user", "content": prompt}], - ) - ) - title = 
(resp.choices[0].message.content or "").strip().strip("\"'") - title = f"{title[:100]}..." if len(title) > 100 else title - if not title: - raise RuntimeError("Generated title is empty.") - async with self._uow() as uow: - await uow.contexts.update_title( - context_id=context_id, title=title, generation_state=TitleGenerationState.COMPLETED - ) - await uow.commit() - except Exception as e: - async with self._uow() as uow: - await uow.contexts.update_title( - context_id=context_id, title=None, generation_state=TitleGenerationState.FAILED - ) - await uow.commit() - logger.warning(f"Failed to generate title for context {context_id}: {e}") - raise e - - async def list_history( - self, *, context_id: UUID, user: User, pagination: PaginationQuery - ) -> PaginatedResult[ContextHistoryItem]: - async with self._uow() as uow: - await uow.contexts.get(context_id=context_id, user_id=user.id) - return await uow.contexts.list_history( - context_id=context_id, - limit=pagination.limit, - page_token=pagination.page_token, - order=pagination.order, - order_by=pagination.order_by, - ) - - async def delete_history_from_id(self, *, context_id: UUID, from_id: UUID, user: User) -> None: - """Delete all history items from a specific item onwards (inclusive)""" - async with self._uow() as uow: - # Verify user has access to this context - await uow.contexts.get(context_id=context_id, user_id=user.id) - # Delete history items from the specified ID onwards - await uow.contexts.delete_history_from_id(context_id=context_id, from_id=from_id) - await uow.commit() diff --git a/apps/adk-server/tests/e2e/agents/conftest.py b/apps/adk-server/tests/e2e/agents/conftest.py index b99ec8cb..10f925af 100644 --- a/apps/adk-server/tests/e2e/agents/conftest.py +++ b/apps/adk-server/tests/e2e/agents/conftest.py @@ -14,7 +14,6 @@ from kagenti_adk.platform import PlatformClient, Provider from kagenti_adk.platform.context import ContextToken from kagenti_adk.server import Server -from 
kagenti_adk.server.store.context_store import ContextStore from tenacity import AsyncRetrying, stop_after_attempt, wait_fixed from tests.conftest import Configuration @@ -27,7 +26,6 @@ async def run_server( test_admin: tuple[str, str], a2a_client_factory: Callable[[AgentCard | dict[str, Any], ContextToken], AsyncIterator[Client]], context_token: ContextToken, - context_store: ContextStore | None = None, ) -> AsyncGenerator[tuple[Server, Client]]: async with asyncio.TaskGroup() as tg: tg.create_task( @@ -35,7 +33,6 @@ async def run_server( server.run, port=port, self_registration_client_factory=lambda: PlatformClient(auth=test_admin), - context_store=context_store, ) ) @@ -66,7 +63,6 @@ def create_server_with_agent( async def _create_server( agent_fn, context_token: ContextToken, - context_store: ContextStore | None = None, ): server = Server() server.agent()(agent_fn) @@ -74,7 +70,6 @@ async def _create_server( server, free_port, a2a_client_factory=a2a_client_factory, - context_store=context_store, context_token=context_token, test_admin=test_admin, ) as (server, client): diff --git a/apps/adk-server/tests/e2e/agents/test_context_store.py b/apps/adk-server/tests/e2e/agents/test_context_store.py deleted file mode 100644 index 44357904..00000000 --- a/apps/adk-server/tests/e2e/agents/test_context_store.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2026 © IBM Corp. 
-# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from collections.abc import AsyncGenerator, AsyncIterator - -import pytest -from a2a.client import Client, ClientEvent, create_text_message_object -from a2a.types import SendMessageRequest, Message, Role, Task -from kagenti_adk.a2a.extensions import PlatformApiExtensionClient, PlatformApiExtensionSpec -from kagenti_adk.a2a.types import RunYield -from kagenti_adk.platform.context import Context, ContextPermissions, ContextToken, Permissions -from kagenti_adk.server import Server -from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore - -pytestmark = pytest.mark.e2e - - -async def get_final_task_from_stream(stream: AsyncIterator[ClientEvent | Message]) -> Task | None: - final_task = None - async for event in stream: - match event: - case (_, task): - final_task = task - return final_task - - -@pytest.fixture -async def history_agent(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: - """Agent that tests context.store.load_history() functionality.""" - - async def history_agent(input: Message, context: RunContext) -> AsyncGenerator[RunYield]: - input.metadata = {"test": "metadata"} - await context.store(input) - async for message in context.load_history(): - message.role = Role.ROLE_AGENT - assert message.metadata == {"test": "metadata"} - yield message - await context.store(message) - - context = await Context.create() - token = await context.generate_token(grant_global_permissions=Permissions(a2a_proxy={"*"})) - async with create_server_with_agent( - history_agent, - context_token=token, - context_store=PlatformContextStore(), - ) as (server, client): - yield server, client - - -def create_message(token: ContextToken, content: str) -> Message: - api_extension_client = PlatformApiExtensionClient(PlatformApiExtensionSpec()) - message = create_text_message_object(content=content) - 
message.metadata = api_extension_client.api_auth_metadata(auth_token=token.token, expires_at=token.expires_at) - message.context_id = token.context_id - return message - - -@pytest.mark.usefixtures("clean_up", "setup_platform_client") -async def test_agent_history(history_agent, subtests): - _, client = history_agent - - with subtests.test("history repeats itself"): - context1 = await Context.create() - token = await context1.generate_token( - grant_context_permissions=ContextPermissions(context_data={"*"}), - grant_global_permissions=Permissions(a2a_proxy={"*"}), - ) - - final_task = await get_final_task_from_stream(client.send_message(SendMessageRequest(message=create_message(token, "first message")))) - agent_messages = [msg.parts[0].text for msg in final_task.history] - assert all(msg.metadata == {"test": "metadata"} for msg in final_task.history) - assert agent_messages == ["first message"] - - final_task = await get_final_task_from_stream(client.send_message(SendMessageRequest(message=create_message(token, "second message")))) - agent_messages = [msg.parts[0].text for msg in final_task.history] - assert all(msg.metadata == {"test": "metadata"} for msg in final_task.history) - assert agent_messages == ["first message", "first message", "second message"] - - final_task = await get_final_task_from_stream(client.send_message(SendMessageRequest(message=create_message(token, "third message")))) - agent_messages = [msg.parts[0].text for msg in final_task.history] - assert all(msg.metadata == {"test": "metadata"} for msg in final_task.history) - assert agent_messages == [ - # first run - "first message", - # second run - "first message", - "second message", - # third run - "first message", - "first message", - "second message", - "third message", - ] - - context1_history = await Context.list_history(context1.id) - assert context1_history.total_count == 14 - - with subtests.test("other context id does not mix history"): - context2 = await Context.create() - token = 
await context2.generate_token( - grant_context_permissions=ContextPermissions(context_data={"*"}), - grant_global_permissions=Permissions(a2a_proxy={"*"}), - ) - final_task = await get_final_task_from_stream(client.send_message(SendMessageRequest(message=create_message(token, "first message")))) - agent_messages = [msg.parts[0].text for msg in final_task.history] - assert agent_messages == ["first message"] - - context1_history = await Context.list_history(context1.id) - assert context1_history.total_count == 14 - - context2_history = await Context.list_history(context2.id) - assert context2_history.total_count == 2 diff --git a/apps/adk-server/tests/e2e/routes/test_contexts.py b/apps/adk-server/tests/e2e/routes/test_contexts.py index b76ee9b9..874a5675 100644 --- a/apps/adk-server/tests/e2e/routes/test_contexts.py +++ b/apps/adk-server/tests/e2e/routes/test_contexts.py @@ -6,7 +6,6 @@ import uuid import pytest -from kagenti_adk.a2a.types import AgentMessage from kagenti_adk.platform.context import Context from httpx import HTTPStatusError @@ -70,102 +69,6 @@ async def test_context_pagination(subtests): assert len(response.items) == 5 # Should return all contexts -@pytest.mark.usefixtures("clean_up", "setup_platform_client") -async def test_context_history_pagination(subtests): - """Test cursor-based pagination for context history endpoint.""" - - # Create a context for testing - context = await Context.create() - - # Create more than 40 history items (default page size) to test pagination - num_items = 45 - - with subtests.test("add multiple history items"): - for i in range(num_items): - message = AgentMessage(text=f"Test message {i}") - await context.add_history_item(data=message) - - with subtests.test("test default pagination (first page)"): - response = await Context.list_history(context.id) - assert len(response.items) == 40 # Default page size - assert response.has_more is True - assert response.next_page_token is not None - - # Verify items are ordered by 
created_at desc (newest first) - created_ats = [item.created_at for item in response.items] - assert created_ats == sorted(created_ats, reverse=False) - - with subtests.test("test pagination with custom limit"): - response = await Context.list_history(context.id, limit=10) - assert len(response.items) == 10 - assert response.has_more is True - assert response.next_page_token is not None - - with subtests.test("test cursor-based pagination"): - # Get first page with limit 20 - first_page = await Context.list_history(context.id, limit=20) - assert len(first_page.items) == 20 - assert first_page.has_more is True - - # Get second page using next_page_token as cursor - second_page = await Context.list_history(context.id, limit=20, page_token=first_page.next_page_token) - assert len(second_page.items) == 20 - assert second_page.has_more is True - - # Get third page - third_page = await Context.list_history(context.id, limit=20, page_token=second_page.next_page_token) - assert len(third_page.items) == 5 # Remaining items - assert third_page.has_more is False - - # Verify no duplicate items across pages - all_items = first_page.items + second_page.items + third_page.items - all_ids = [item.id for item in all_items if hasattr(item, "id")] - assert len(all_ids) == len(set(all_ids)) # No duplicates - - with subtests.test("test ascending order"): - response = await Context.list_history(context.id, order="asc", limit=5) - created_ats = [item.created_at for item in response.items] - assert created_ats == sorted(created_ats) # Should be ascending - - with subtests.test("test list_all_history method"): - # Test the list_all_history method that automatically iterates through all pages - all_items = [] - async for item in Context.list_all_history(context.id): - all_items.append(item) - - assert len(all_items) == num_items - - # Verify chronological order (oldest first since it yields in order) - created_ats = [item.created_at for item in all_items] - # Note: list_all_history should 
maintain the order from list_history (desc by default) - # but iterate through all pages - - -@pytest.mark.usefixtures("clean_up", "setup_platform_client") -async def test_context_empty_filtering(subtests): - """Test filtering contexts based on whether they have history records.""" - - with subtests.test("create contexts with and without history"): - # Create empty context (no history) - empty_context = await Context.create() - - # Create context with history - context_with_history = await Context.create() - message = AgentMessage(text="Test message") - await context_with_history.add_history_item(data=message) - - with subtests.test("include_empty=True returns all contexts"): - response = await Context.list(include_empty=True) - assert len(response.items) == 2 # Should include both contexts - - with subtests.test("include_empty=False returns only contexts with history"): - response = await Context.list(include_empty=False) - context_ids = [ctx.id for ctx in response.items] - assert len(context_ids) == 1 - assert context_with_history.id in context_ids - assert empty_context.id not in context_ids - - @pytest.mark.usefixtures("clean_up", "setup_platform_client") async def test_context_update_and_patch(subtests): """Test updating and patching context metadata.""" @@ -255,42 +158,3 @@ async def test_context_provider_filtering(subtests): assert fetched_context.provider_id == provider1.id -@pytest.mark.usefixtures("clean_up", "setup_platform_client") -async def test_context_delete_context_history_from_id(subtests): - """Test deleting context history from a specific item ID onwards.""" - - context = None - history_items = [] - n_messages = 3 - - with subtests.test("create context and add multiple history items"): - context = await Context.create() - for i in range(n_messages): - message = AgentMessage(text=f"Test message {i}") - await context.add_history_item(data=message) - - history = await context.list_history(limit=50) - history_items = history.items - assert 
len(history.items) == n_messages - - with subtests.test("delete history from a middle item onwards"): - await context.delete_history_from_id(from_id=history_items[1].id) - - remaining_history = await context.list_history(limit=50) - remaining_ids = [item.id for item in remaining_history.items] - assert len(remaining_history.items) == 1 - assert history_items[0].id in remaining_ids - assert history_items[1].id not in remaining_ids - assert history_items[2].id not in remaining_ids - - with subtests.test("delete with nonexistent item_id raises error"): - nonexistent_id = uuid.uuid4() - with pytest.raises(HTTPStatusError) as exc_info: - await context.delete_history_from_id(from_id=nonexistent_id) - assert exc_info.value.response.status_code == 404 - - with subtests.test("delete from first item deletes all"): - await context.delete_history_from_id(from_id=remaining_ids[0]) - # await context.delete_history_from_id(from_id=remaining_ids[0]) - remaining_history = await context.list_history(limit=50) - assert len(remaining_history.items) == 0 diff --git a/apps/adk-ui/src/api/a2a/jsonrpc-client.ts b/apps/adk-ui/src/api/a2a/jsonrpc-client.ts index 9083bd49..0d4e5048 100644 --- a/apps/adk-ui/src/api/a2a/jsonrpc-client.ts +++ b/apps/adk-ui/src/api/a2a/jsonrpc-client.ts @@ -8,6 +8,20 @@ import { agentCardSchema, streamResponseSchema } from '@kagenti/adk'; import { EventSourceParserStream } from 'eventsource-parser/stream'; import { v4 as uuid } from 'uuid'; +export interface ListTasksParams { + contextId?: string; + status?: string; + pageSize?: number; + pageToken?: string; +} + +export interface ListTasksResponse { + tasks: Task[]; + nextPageToken?: string; + totalSize?: number; + pageSize?: number; +} + export interface A2AClient { getAgentCard(): Promise; sendMessageStream(params: { @@ -17,6 +31,7 @@ export interface A2AClient { }): AsyncIterable; getTask(params: { id: string }): Promise; cancelTask(params: { id: string }): Promise; + listTasks(params: ListTasksParams): 
Promise; } interface CreateClientParams { @@ -117,6 +132,10 @@ export function createA2AClient({ endpointUrl, agentCard, fetchImpl, extensions async cancelTask(params) { return jsonRpcRequest('CancelTask', params) as Promise; }, + + async listTasks(params) { + return jsonRpcRequest('ListTasks', { ...params }) as Promise; + }, }; } diff --git a/apps/adk-ui/src/api/a2a/list-tasks.ts b/apps/adk-ui/src/api/a2a/list-tasks.ts new file mode 100644 index 00000000..25f09a0e --- /dev/null +++ b/apps/adk-ui/src/api/a2a/list-tasks.ts @@ -0,0 +1,79 @@ +/** + * Copyright 2026 © IBM Corp. + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Task } from '@kagenti/adk'; +import { v4 as uuid } from 'uuid'; + +import { ensureToken } from '#app/(auth)/rsc.tsx'; +import { runtimeConfig } from '#contexts/App/runtime-config.ts'; +import { getBaseUrl } from '#utils/api/getBaseUrl.ts'; + +export interface ListTasksParams { + contextId?: string; + status?: string; + pageSize?: number; + pageToken?: string; +} + +export interface ListTasksResponse { + tasks: Task[]; + nextPageToken?: string; + totalSize?: number; + pageSize?: number; +} + +/** + * Server-side function to fetch tasks from the A2A proxy via JSON-RPC. + * Used in React Server Components. 
+ */ +export async function fetchTasksForContext( + providerId: string, + contextId: string, +): Promise { + try { + const baseUrl = getBaseUrl(); + const endpointUrl = `${baseUrl}/api/v1/a2a/${providerId}/`; + + const { isAuthEnabled } = runtimeConfig; + const headers: Record = { 'Content-Type': 'application/json' }; + + if (isAuthEnabled) { + const token = await ensureToken(); + if (token?.accessToken) { + headers['Authorization'] = `Bearer ${token.accessToken}`; + } + } + + const response = await fetch(endpointUrl, { + method: 'POST', + headers, + body: JSON.stringify({ + jsonrpc: '2.0', + id: uuid(), + method: 'ListTasks', + params: { + contextId, + }, + }), + }); + + if (!response.ok) { + console.error(`ListTasks request failed: ${response.status} ${response.statusText}`); + return undefined; + } + + const data = await response.json(); + + if (data.error) { + console.error('ListTasks error:', data.error); + return undefined; + } + + return data.result as ListTasksResponse; + } catch (error) { + console.error('Failed to fetch tasks:', error); + return undefined; + } +} diff --git a/apps/adk-ui/src/modules/history/utils.ts b/apps/adk-ui/src/modules/history/utils.ts index fa30814a..af0996ee 100644 --- a/apps/adk-ui/src/modules/history/utils.ts +++ b/apps/adk-ui/src/modules/history/utils.ts @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { Artifact, ContextHistory, Message } from '@kagenti/adk'; +import type { Artifact, ContextHistory, Message, Task } from '@kagenti/adk'; import { v4 as uuid } from 'uuid'; import { processMessageMetadata, processParts } from '#api/a2a/part-processors.ts'; @@ -95,3 +95,51 @@ export function convertHistoryToUIMessages(history: ContextHistory[]): UIMessage return messages; } + +/** + * Convert A2A Tasks (with their history and artifacts) into UI messages. + * Tasks are expected in chronological order (oldest first). 
+ */ +export function convertTasksToUIMessages(tasks: Task[]): UIMessage[] { + const allMessages: UIMessage[] = []; + + for (const task of tasks) { + const taskId = task.id; + + // Process history messages + for (const msg of task.history ?? []) { + const uiMessage = processHistoryMessage(msg, taskId); + + const lastMessage = allMessages.at(-1); + const shouldGroup = lastMessage && lastMessage.role === uiMessage.role && lastMessage.taskId === uiMessage.taskId; + + if (shouldGroup) { + allMessages.splice(-1, 1, { + ...lastMessage, + parts: [...uiMessage.parts, ...lastMessage.parts], + }); + } else { + allMessages.push(uiMessage); + } + } + + // Process artifacts + for (const artifact of task.artifacts ?? []) { + const uiMessage = processHistoryArtifact(artifact, taskId); + + const lastMessage = allMessages.at(-1); + const shouldGroup = lastMessage && lastMessage.role === uiMessage.role && lastMessage.taskId === uiMessage.taskId; + + if (shouldGroup) { + allMessages.splice(-1, 1, { + ...lastMessage, + parts: [...uiMessage.parts, ...lastMessage.parts], + }); + } else { + allMessages.push(uiMessage); + } + } + } + + return allMessages; +} diff --git a/apps/adk-ui/src/modules/messages/contexts/Messages/MessagesProvider.tsx b/apps/adk-ui/src/modules/messages/contexts/Messages/MessagesProvider.tsx index 9e8a99dc..84db539e 100644 --- a/apps/adk-ui/src/modules/messages/contexts/Messages/MessagesProvider.tsx +++ b/apps/adk-ui/src/modules/messages/contexts/Messages/MessagesProvider.tsx @@ -3,70 +3,22 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { type PropsWithChildren, useCallback, useEffect, useMemo } from 'react'; +import { type PropsWithChildren, useCallback, useMemo } from 'react'; -import { useFetchNextPage } from '#hooks/useFetchNextPage.ts'; import { useImmerWithGetter } from '#hooks/useImmerWithGetter.ts'; -import { convertHistoryToUIMessages } from '#modules/history/utils.ts'; +import { convertTasksToUIMessages } from '#modules/history/utils.ts'; import type 
{ UIMessage } from '#modules/messages/types.ts'; -import { isAgentMessage } from '#modules/messages/utils.ts'; -import { LIST_CONTEXT_HISTORY_DEFAULT_QUERY } from '#modules/platform-context/api/constants.ts'; -import { useListContextHistory } from '#modules/platform-context/api/queries/useListContextHistory.ts'; -import { isHistoryMessage } from '#modules/platform-context/api/utils.ts'; import { usePlatformContext } from '#modules/platform-context/contexts/index.ts'; import { MessagesContext } from './messages-context'; export function MessagesProvider({ children }: PropsWithChildren) { - const { contextId, history: initialHistory } = usePlatformContext(); - - const { data: history, ...queryRest } = useListContextHistory({ - context_id: contextId ?? undefined, - query: LIST_CONTEXT_HISTORY_DEFAULT_QUERY, - initialData: initialHistory, - // Ensures newly created messages are not fetched from history - initialPageParam: initialHistory?.next_page_token ?? undefined, - // Ensures history is not fetched for newly created contexts, where previous rule isn't sufficient to prevent message duplication - enabled: Boolean(initialHistory), - }); + const { initialTasks } = usePlatformContext(); const [messages, getMessages, setMessages] = useImmerWithGetter( - convertHistoryToUIMessages(history ?? []), + convertTasksToUIMessages(initialTasks ?? []), ); - useEffect(() => { - if (history) { - setMessages((messages) => { - const lastMessage = messages.at(-1); - const lastMessageHistoryIndex = lastMessage - ? history.findIndex(({ data }) => - isHistoryMessage(data) - ? data.messageId === lastMessage?.id - : isAgentMessage(lastMessage) && data.artifactId === lastMessage?.artifactId, - ) - : null; - - const historyContainsLastMessage = lastMessageHistoryIndex !== null && lastMessageHistoryIndex >= 0; - const newItems = historyContainsLastMessage ? 
history.slice(lastMessageHistoryIndex) : history; - - // Remove last message and convert it again from history, because - // newly fetched history can contain subsequent trajectories of the message - if (historyContainsLastMessage) { - messages.splice(-1, 1); - } - - messages.push(...convertHistoryToUIMessages(newItems)); - }); - } - }, [history, setMessages]); - - const { fetchNextPage, isFetching, hasNextPage } = queryRest; - const { ref: fetchNextPageInViewAnchorRef } = useFetchNextPage({ - fetchNextPage, - isFetching, - hasNextPage, - }); - const isLastMessage = useCallback((message: UIMessage) => getMessages().at(0)?.id === message.id, [getMessages]); const value = useMemo( @@ -76,11 +28,13 @@ export function MessagesProvider({ children }: PropsWithChildren) { setMessages, isLastMessage, queryControl: { - ...queryRest, - fetchNextPageInViewAnchorRef, + fetchNextPageInViewAnchorRef: { current: null } as React.RefObject, + isFetching: false, + isFetchingNextPage: false, + hasNextPage: false, }, }), - [messages, getMessages, setMessages, isLastMessage, queryRest, fetchNextPageInViewAnchorRef], + [messages, getMessages, setMessages, isLastMessage], ); return {children}; diff --git a/apps/adk-ui/src/modules/messages/contexts/Messages/messages-context.ts b/apps/adk-ui/src/modules/messages/contexts/Messages/messages-context.ts index 512dff3c..6551431e 100644 --- a/apps/adk-ui/src/modules/messages/contexts/Messages/messages-context.ts +++ b/apps/adk-ui/src/modules/messages/contexts/Messages/messages-context.ts @@ -4,11 +4,11 @@ */ 'use client'; +import type { RefObject } from 'react'; import { createContext } from 'react'; import type { Updater } from '#hooks/useImmerWithGetter.ts'; import type { UIMessage } from '#modules/messages/types.ts'; -import type { useListContextHistory } from '#modules/platform-context/api/queries/useListContextHistory.ts'; export const MessagesContext = createContext(null); @@ -18,6 +18,9 @@ export interface MessagesContextValue { 
getMessages: () => UIMessage[]; setMessages: Updater; queryControl: { - fetchNextPageInViewAnchorRef: (node?: Element | null) => void; - } & Omit, 'data'>; + fetchNextPageInViewAnchorRef: RefObject; + isFetching: boolean; + isFetchingNextPage: boolean; + hasNextPage: boolean; + }; } diff --git a/apps/adk-ui/src/modules/platform-context/api/constants.ts b/apps/adk-ui/src/modules/platform-context/api/constants.ts index 056f2afb..dc3d1db9 100644 --- a/apps/adk-ui/src/modules/platform-context/api/constants.ts +++ b/apps/adk-ui/src/modules/platform-context/api/constants.ts @@ -3,8 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { ListContextHistoryRequest, ListContextsRequest } from '@kagenti/adk'; +import type { ListContextsRequest } from '@kagenti/adk'; export const LIST_CONTEXTS_DEFAULT_QUERY: ListContextsRequest['query'] = { limit: 10, include_empty: false }; - -export const LIST_CONTEXT_HISTORY_DEFAULT_QUERY: ListContextHistoryRequest['query'] = { limit: 10 }; diff --git a/apps/adk-ui/src/modules/platform-context/api/index.ts b/apps/adk-ui/src/modules/platform-context/api/index.ts index 5bdc9b2c..a4e9ef47 100644 --- a/apps/adk-ui/src/modules/platform-context/api/index.ts +++ b/apps/adk-ui/src/modules/platform-context/api/index.ts @@ -7,13 +7,11 @@ import type { CreateContextRequest, CreateContextTokenRequest, DeleteContextRequest, - ListContextHistoryRequest, ListContextsRequest, } from '@kagenti/adk'; import { type MatchModelProvidersRequest, unwrapResult } from '@kagenti/adk'; import { adkClient } from '#api/adk-client.ts'; -import { fetchEntity } from '#api/utils.ts'; import type { PatchContextMetadataRequest } from './types'; import { contextSchema, listContextsResponseSchema } from './types'; @@ -39,13 +37,6 @@ export async function deleteContext(request: DeleteContextRequest) { return result; } -export async function listContextHistory(request: ListContextHistoryRequest) { - const response = await adkClient.listContextHistory(request); - const 
result = unwrapResult(response); - - return result; -} - export async function patchContextMetadata(request: PatchContextMetadataRequest) { const response = await adkClient.patchContextMetadata(request); const result = unwrapResult(response, contextSchema); @@ -66,7 +57,3 @@ export async function createContextToken(request: CreateContextTokenRequest) { return result; } - -export async function fetchContextHistory(request: ListContextHistoryRequest) { - return await fetchEntity(() => listContextHistory(request)); -} diff --git a/apps/adk-ui/src/modules/platform-context/api/keys.ts b/apps/adk-ui/src/modules/platform-context/api/keys.ts index 0adf0aa9..c1a8031b 100644 --- a/apps/adk-ui/src/modules/platform-context/api/keys.ts +++ b/apps/adk-ui/src/modules/platform-context/api/keys.ts @@ -3,15 +3,12 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { ListContextHistoryRequest, ListContextsRequest } from '@kagenti/adk'; +import type { ListContextsRequest } from '@kagenti/adk'; export const contextKeys = { all: () => ['contexts'] as const, lists: () => [...contextKeys.all(), 'list'] as const, list: ({ query = {} }: ListContextsRequest) => [...contextKeys.lists(), query] as const, - histories: () => [...contextKeys.all(), 'history'] as const, - history: ({ context_id, query = {} }: ListContextHistoryRequest) => - [...contextKeys.histories(), context_id, query] as const, tokens: () => [...contextKeys.all(), 'token'] as const, token: (contextId: string, providerId: string) => [...contextKeys.tokens(), contextId, providerId] as const, }; diff --git a/apps/adk-ui/src/modules/platform-context/api/queries/useListContextHistory.ts b/apps/adk-ui/src/modules/platform-context/api/queries/useListContextHistory.ts deleted file mode 100644 index 61e4ac3e..00000000 --- a/apps/adk-ui/src/modules/platform-context/api/queries/useListContextHistory.ts +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2026 © IBM Corp. 
- * SPDX-License-Identifier: Apache-2.0 - */ - -import { useInfiniteQuery } from '@tanstack/react-query'; -import type { ListContextHistoryRequest, ListContextHistoryResponse } from '@kagenti/adk'; - -import type { PartialBy } from '#@types/utils.ts'; -import { isNotNull } from '#utils/helpers.ts'; - -import { listContextHistory } from '..'; -import { contextKeys } from '../keys'; - -type Params = PartialBy & { - initialData?: ListContextHistoryResponse; - enabled?: boolean; - initialPageParam?: string; -}; - -export function useListContextHistory({ - context_id, - query: queryParams, - initialData, - initialPageParam, - enabled = true, -}: Params) { - const query = useInfiniteQuery({ - queryKey: contextKeys.history({ - context_id: context_id!, - query: queryParams, - }), - queryFn: ({ pageParam }: { pageParam?: string }) => { - return listContextHistory({ - context_id: context_id!, - query: { - ...queryParams, - page_token: pageParam, - }, - }); - }, - initialPageParam, - getNextPageParam: (lastPage) => { - return lastPage?.has_more && lastPage.next_page_token ? lastPage.next_page_token : undefined; - }, - select: (data) => { - if (!data) { - return undefined; - } - - const items = data.pages.flatMap((page) => page?.items).filter(isNotNull); - - return items; - }, - enabled: Boolean(context_id) && enabled, - initialData: initialData ? 
{ pages: [initialData], pageParams: [undefined] } : undefined, - }); - - return query; -} diff --git a/apps/adk-ui/src/modules/platform-context/api/types.ts b/apps/adk-ui/src/modules/platform-context/api/types.ts index 73705262..70a84929 100644 --- a/apps/adk-ui/src/modules/platform-context/api/types.ts +++ b/apps/adk-ui/src/modules/platform-context/api/types.ts @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { ContextHistory } from '@kagenti/adk'; import { contextSchema as sdkContextSchema, listContextsResponseSchema as sdkListContextsResponseSchema, @@ -41,7 +40,3 @@ export const patchContextMetadataRequestSchema = sdkPatchContextMetadataRequestS }); export type PatchContextMetadataRequest = z.infer; -// - -export type HistoryItem = ContextHistory['data']; -export type HistoryMessage = Extract; diff --git a/apps/adk-ui/src/modules/platform-context/api/utils.ts b/apps/adk-ui/src/modules/platform-context/api/utils.ts deleted file mode 100644 index 31f021f8..00000000 --- a/apps/adk-ui/src/modules/platform-context/api/utils.ts +++ /dev/null @@ -1,10 +0,0 @@ -/** - * Copyright 2026 © IBM Corp. 
- * SPDX-License-Identifier: Apache-2.0 - */ - -import type { HistoryItem, HistoryMessage } from './types'; - -export function isHistoryMessage(item: HistoryItem): item is HistoryMessage { - return 'messageId' in item; -} diff --git a/apps/adk-ui/src/modules/platform-context/contexts/PlatformContextProvider.tsx b/apps/adk-ui/src/modules/platform-context/contexts/PlatformContextProvider.tsx index 470e7096..c249b8ba 100644 --- a/apps/adk-ui/src/modules/platform-context/contexts/PlatformContextProvider.tsx +++ b/apps/adk-ui/src/modules/platform-context/contexts/PlatformContextProvider.tsx @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ 'use client'; -import type { ListContextHistoryResponse } from '@kagenti/adk'; +import type { Task } from '@kagenti/adk'; import { type PropsWithChildren, useCallback, useState } from 'react'; import type { Agent } from '#modules/agents/api/types.ts'; @@ -14,10 +14,10 @@ import { PlatformContext } from './platform-context'; interface Props { contextId?: string; - history?: ListContextHistoryResponse; + initialTasks?: Task[]; } -export function PlatformContextProvider({ history, contextId: contextIdProp, children }: PropsWithChildren) { +export function PlatformContextProvider({ initialTasks, contextId: contextIdProp, children }: PropsWithChildren) { const [contextId, setContextId] = useState(contextIdProp ?? 
null); const { mutateAsync: createContext } = useCreateContext({ @@ -61,7 +61,7 @@ export function PlatformContextProvider({ history, contextId: contextIdProp, chi ContextId; resetContext: () => void; diff --git a/apps/adk-ui/src/modules/runs/components/AgentRun.tsx b/apps/adk-ui/src/modules/runs/components/AgentRun.tsx index 45e12ed8..9d05e5e6 100644 --- a/apps/adk-ui/src/modules/runs/components/AgentRun.tsx +++ b/apps/adk-ui/src/modules/runs/components/AgentRun.tsx @@ -5,9 +5,8 @@ import { notFound } from 'next/navigation'; +import { fetchTasksForContext } from '#api/a2a/list-tasks.ts'; import { runtimeConfig } from '#contexts/App/runtime-config.ts'; -import { LIST_CONTEXT_HISTORY_DEFAULT_QUERY } from '#modules/platform-context/api/constants.ts'; -import { fetchContextHistory } from '#modules/platform-context/api/index.ts'; import { PlatformContextProvider } from '#modules/platform-context/contexts/PlatformContextProvider.tsx'; import { RunView } from '#modules/runs/components/RunView.tsx'; @@ -23,12 +22,7 @@ export async function AgentRun({ providerId, contextId }: Props) { const { featureFlags } = runtimeConfig; const agentPromise = fetchAgent(providerId); - const contextHistoryPromise = contextId - ? fetchContextHistory({ - context_id: contextId, - query: LIST_CONTEXT_HISTORY_DEFAULT_QUERY, - }) - : undefined; + const tasksPromise = contextId ? 
fetchTasksForContext(providerId, contextId) : undefined; const agent = await agentPromise; @@ -40,14 +34,14 @@ export async function AgentRun({ providerId, contextId }: Props) { } } - const contextHistory = await contextHistoryPromise; + const tasksResponse = await tasksPromise; - if (contextId && !contextHistory) { + if (contextId && !tasksResponse) { notFound(); } return ( - + ); diff --git a/apps/adk-ui/src/modules/runs/contexts/agent-run/AgentRunProvider.tsx b/apps/adk-ui/src/modules/runs/contexts/agent-run/AgentRunProvider.tsx index 5bed7bcb..e2131515 100644 --- a/apps/adk-ui/src/modules/runs/contexts/agent-run/AgentRunProvider.tsx +++ b/apps/adk-ui/src/modules/runs/contexts/agent-run/AgentRunProvider.tsx @@ -294,7 +294,6 @@ function AgentRunProvider({ agent, children }: PropsWithChildren) { pendingSubscription.current = undefined; queryClient.invalidateQueries({ queryKey: contextKeys.lists() }); - queryClient.invalidateQueries({ queryKey: contextKeys.history({ context_id: contextId }) }); } }, [ diff --git a/docs/development/agent-integration/canvas.mdx b/docs/development/agent-integration/canvas.mdx index e3817925..a2bbc5b1 100644 --- a/docs/development/agent-integration/canvas.mdx +++ b/docs/development/agent-integration/canvas.mdx @@ -107,7 +107,6 @@ async def code_agent( llm: Annotated[LLMServiceExtensionServer, LLMServiceExtensionSpec.single_demand()], canvas: Annotated[CanvasExtensionServer, CanvasExtensionSpec()], ): - await context.store(message) canvas_edit = await canvas.parse_canvas_edit_request(message=message) # Adapt system prompt based on whether this is an edit or new generation @@ -116,8 +115,6 @@ async def code_agent( artifact = await call_llm(llm, system_prompt, message) yield artifact - await context.store(artifact) - if __name__ == "__main__": server.run() diff --git a/docs/development/agent-integration/multi-turn.mdx b/docs/development/agent-integration/multi-turn.mdx index 61246c65..de6417a4 100644 --- 
a/docs/development/agent-integration/multi-turn.mdx +++ b/docs/development/agent-integration/multi-turn.mdx @@ -9,10 +9,7 @@ When building conversational AI agents, one of the key requirements is maintaini | Operation | Purpose | | :--- | :--- | -| **await context.store(input)** | Stores current user message in conversation history. Storage of messages must be explicitly requested| -| **await context.store(response)** | Stores agent’s responses in conversation history, and must be explicitly requested | -| **context: RunContext)** | Sets up a RunContext instance for storing and accessing the conversation history | -| **context_store=PlatformContextStore()** | Configures server to use the platform’s persistent context store to maintain conversation history across agent restarts | +| **context: RunContext)** | Sets up a RunContext instance for accessing the conversation history | ## Simple History Access Example @@ -40,9 +37,6 @@ server = Server() async def basic_history_example(input: Message, context: RunContext): """Agent that demonstrates conversation history access""" - # Store the current message in the context store - await context.store(input) - # Get the current user message current_message = get_message_text(input) print(f"Current message: {current_message}") @@ -57,9 +51,6 @@ async def basic_history_example(input: Message, context: RunContext): message = AgentMessage(text=f"Hello! I can see we have {len(history)} messages in our conversation.") yield message - # Store the message in the context store - await context.store(message) - def run(): server.run(host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000))) @@ -72,10 +63,8 @@ if __name__ == "__main__": ### Steps -1. **Access conversation history:** Use `RunContext` to set up an instance of the conversation history to store and load previous messages. -1. **Store incoming messages:** Use `await context.store(input)` to store the current user message in the conversation history. +1. 
**Access conversation history:** Use `RunContext` to set up an instance of the conversation history to load previous messages. 1. **Filter and process history:** Retrieve the conversation history with `load_history()` and filter to get the messages relevant to your agent's logic. -1. **Store agent responses:** Use `await context.store(response)` to store your agent's responses for future conversation context. ## Streaming with Buffered History Example @@ -95,7 +84,6 @@ from a2a.utils.message import get_message_text from kagenti_adk.a2a.types import AgentMessage from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore server = Server() @@ -120,9 +108,6 @@ async def streaming_agent_w_single_history_write_example(input: Message, context Stream partial answers, execute tools, and persist one finalized assistant message. See other examples for actual implementation of multi-turn conversation agent with tool use. """ - # Store the user input as the first persisted item for this turn. - await context.store(data=input) - history = [message async for message in context.load_history() if isinstance(message, Message) and message.parts] current_message = get_message_text(input) @@ -159,20 +144,14 @@ async def streaming_agent_w_single_history_write_example(input: Message, context # This does not need to be the go-to approach in all cases, sometimes the partial outputs are of no value and one does not want them to be properly stored. # # Why not store each chunk? - # - Calling `context.store()`, PlatformContextStore saves every message as a distinct history item. - # - Storing per chunk would fragment one assistant turn into many partial messages. - # - A single aggregated write keeps replay, memory, and history semantics clean. 
- # aggregated_response = AgentMessage(text="\n".join(buffered_parts)) yield "Final result check:\n" + str(aggregated_response.text) - await context.store(data=aggregated_response) def run(): server.run( host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", "8000")), - context_store=PlatformContextStore(), ) @@ -183,9 +162,8 @@ if __name__ == "__main__": ### When to use buffering -- Use **simple yield + store** when your agent emits a single final response. -- Use **stream + buffer + single store** when your agent emits multiple partial chunks which are streamed to the user. -- With `PlatformContextStore`, each `context.store()` call creates a persisted history item, so buffering prevents chunk-level history fragmentation. +- Use **simple yield** when your agent emits a single final response. +- Use **stream + buffer** when your agent emits multiple partial chunks which are streamed to the user. ## Advanced BeeAI Framework Example @@ -205,7 +183,6 @@ from kagenti_adk.a2a.extensions import LLMServiceExtensionServer, LLMServiceExte from kagenti_adk.a2a.types import AgentMessage from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore from beeai_framework.adapters.agentstack.backend.chat import AgentStackChatModel from beeai_framework.agents.requirement import RequirementAgent from beeai_framework.agents.requirement.requirements.conditional import ConditionalRequirement @@ -236,8 +213,6 @@ async def advanced_history_example( llm: Annotated[LLMServiceExtensionServer, LLMServiceExtensionSpec.single_demand()], ): """Multi-turn chat agent with conversation memory and LLM integration""" - await context.store(input) - # Load conversation history history = [message async for message in context.load_history() if isinstance(message, Message) and message.parts] @@ -273,14 +248,12 @@ async def advanced_history_example( response = 
AgentMessage(text=step.input["response"]) yield response - await context.store(response) def run(): server.run( host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", "8000")), - context_store=PlatformContextStore(), # Enable persistent storage ) @@ -295,31 +268,9 @@ This advanced example demonstrates several key concepts: - **Framework Integration:** Leverages the BeeAI Framework for sophisticated agent capabilities - **Memory Management:** Converts conversation history to framework format and loads it into agent memory - **Tool Usage:** Includes thinking tools and conditional requirements for better reasoning -- **Persistent Storage:** Uses `PlatformContextStore` for conversation persistence ## Using Content History -### Persistent Storage Example - -By default, conversation history is stored in memory and is lost when the agent process restarts. For production applications, you'll want to use persistent context storage to maintain conversation history across agent restarts. The `PlatformContextStore` automatically handles conversation persistence, ensuring that users can continue their conversations even after agent restarts or deployments. - -```python -import os -from kagenti_adk.server import Server -from kagenti_adk.server.store.platform_context_store import PlatformContextStore - -server = Server() - -def run(): - server.run( - host=os.getenv("HOST", "127.0.0.1"), - port=int(os.getenv("PORT", 8000)), - context_store=PlatformContextStore() - ) -``` - - - ### History Contents The `context.load_history()` method returns an async iterator containing all items in the conversation, including the current message. This can include: @@ -339,49 +290,7 @@ The history includes the current message, so if you want only previous messages, The history iterator returns all message types. Always filter messages using `isinstance(message, Message)` to ensure you're working with the correct message format. 
-### Editing and Removing Messages from History - -Sometimes you may need to edit a previous message in a conversation or remove messages that are no longer relevant. -The Kagenti ADK provides a mechanism to delete history items from a specific point onward, allowing you to effectively “rewind” the conversation and replace a message with an edited version. -Possible use cases include editing a previous message, clearing irrelevant exchanges, or removing messages that resulted from processing errors. - -Here's an example of a function for editing a user message in a conversation using the context API. This assumes you know the context message id, which can be obtained as an id field of an object returned by `RunContext.load_history(load_history_items=True)`, `Context.list_history` or `Context.list_all_history`. - -```python -import uuid -from uuid import UUID -from typing import Any - -from a2a.types import Message, Part, Role -from kagenti_adk.platform.context import Context -from kagenti_adk.server.context import RunContext - -async def edit_message_in_context(run_context: RunContext, id: UUID, new_text: str): - # Step 1: Delete from this message onwards - await run_context.delete_history_from_id(from_id=id) - - # Step 2: Create the corrected message - corrected_message = Message( - message_id=str(uuid.uuid4()), - parts=[Part(text=new_text)], - role=Role.ROLE_USER, - ) - - # Step 3: Store the corrected message - await run_context.store(data=corrected_message) -``` - - -When you delete history from a specific message onwards, all messages created after that point (including the message itself) are removed. This effectively creates a new conversation branch starting from the message before the deleted one. - - - -This operation is permanent. Once messages are deleted, they cannot be recovered. Consider informing users about this operation or implementing a confirmation step for important conversations. 
- - -### Message Storage Guidelines - -Since messages are not automatically stored, you need to explicitly call `context.store()` for any message you want to be available in future interactions. Here are the key guidelines: +### Message History #### Store Request Example @@ -389,15 +298,9 @@ Since messages are not automatically stored, you need to explicitly call `contex ```python @server.agent() async def my_agent(input: Message, context: RunContext): - # Store the incoming user message immediately - await context.store(input) - # Process the message and generate response response = AgentMessage(text="Your response here") yield response - - # Store the agent's response after yielding - await context.store(response) ``` #### What to Store diff --git a/docs/development/agent-integration/rag.mdx b/docs/development/agent-integration/rag.mdx index 3e7180d4..75543a7f 100644 --- a/docs/development/agent-integration/rag.mdx +++ b/docs/development/agent-integration/rag.mdx @@ -548,7 +548,6 @@ async def conversation_rag_agent_example( vector_store = await create_vector_store(embedding_client, embedding_model) # store vector store id in context for future messages data_part = DataPart(data={"vector_store_id": vector_store.id}) - await context.store(AgentMessage(parts=[data_part])) # Process files, add to vector store for file in files: diff --git a/examples/agent-integration/canvas/canvas-with-llm/src/canvas_with_llm/agent.py b/examples/agent-integration/canvas/canvas-with-llm/src/canvas_with_llm/agent.py index cc2a06c9..0be528cc 100644 --- a/examples/agent-integration/canvas/canvas-with-llm/src/canvas_with_llm/agent.py +++ b/examples/agent-integration/canvas/canvas-with-llm/src/canvas_with_llm/agent.py @@ -75,7 +75,6 @@ async def canvas_with_llm_example( llm: Annotated[LLMServiceExtensionServer, LLMServiceExtensionSpec.single_demand()], canvas: Annotated[CanvasExtensionServer, CanvasExtensionSpec()], ): - await context.store(message) canvas_edit = await 
canvas.parse_canvas_edit_request(message=message) # Adapt system prompt based on whether this is an edit or new generation @@ -84,8 +83,6 @@ async def canvas_with_llm_example( artifact = await call_llm(llm, system_prompt, message) yield artifact - await context.store(artifact) - def run(): server.run(host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000))) diff --git a/examples/agent-integration/multi-turn/advanced-history/src/advanced_history/agent.py b/examples/agent-integration/multi-turn/advanced-history/src/advanced_history/agent.py index eae75f4f..3a9734bd 100644 --- a/examples/agent-integration/multi-turn/advanced-history/src/advanced_history/agent.py +++ b/examples/agent-integration/multi-turn/advanced-history/src/advanced_history/agent.py @@ -10,7 +10,6 @@ from kagenti_adk.a2a.types import AgentMessage from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore from beeai_framework.adapters.agentstack.backend.chat import AgentStackChatModel from beeai_framework.agents.requirement import RequirementAgent from beeai_framework.agents.requirement.requirements.conditional import ConditionalRequirement @@ -41,8 +40,6 @@ async def advanced_history_example( llm: Annotated[LLMServiceExtensionServer, LLMServiceExtensionSpec.single_demand()], ): """Multi-turn chat agent with conversation memory and LLM integration""" - await context.store(input) - # Load conversation history history = [message async for message in context.load_history() if isinstance(message, Message) and message.parts] @@ -78,14 +75,12 @@ async def advanced_history_example( response = AgentMessage(text=step.input["response"]) yield response - await context.store(response) def run(): server.run( host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", "8000")), - context_store=PlatformContextStore(), # Enable persistent storage ) diff --git 
a/examples/agent-integration/multi-turn/basic-history/src/basic_history/agent.py b/examples/agent-integration/multi-turn/basic-history/src/basic_history/agent.py index f17cc5e5..fc38c9df 100644 --- a/examples/agent-integration/multi-turn/basic-history/src/basic_history/agent.py +++ b/examples/agent-integration/multi-turn/basic-history/src/basic_history/agent.py @@ -17,9 +17,6 @@ async def basic_history_example(input: Message, context: RunContext): """Agent that demonstrates conversation history access""" - # Store the current message in the context store - await context.store(input) - # Get the current user message current_message = get_message_text(input) print(f"Current message: {current_message}") @@ -34,9 +31,6 @@ async def basic_history_example(input: Message, context: RunContext): message = AgentMessage(text=f"Hello! I can see we have {len(history)} messages in our conversation.") yield message - # Store the message in the context store - await context.store(message) - def run(): server.run(host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000))) diff --git a/examples/agent-integration/multi-turn/streaming-agent-history/src/streaming_agent_history/agent.py b/examples/agent-integration/multi-turn/streaming-agent-history/src/streaming_agent_history/agent.py index 6ebd02b5..5bf21dbc 100644 --- a/examples/agent-integration/multi-turn/streaming-agent-history/src/streaming_agent_history/agent.py +++ b/examples/agent-integration/multi-turn/streaming-agent-history/src/streaming_agent_history/agent.py @@ -8,7 +8,6 @@ from kagenti_adk.a2a.types import AgentMessage from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore server = Server() @@ -33,9 +32,6 @@ async def streaming_agent_w_single_history_write_example(input: Message, context Stream partial answers, execute tools, and persist one finalized assistant message. 
See other examples for actual implementation of multi-turn conversation agent with tool use. """ - # Store the user input as the first persisted item for this turn. - await context.store(data=input) - history = [message async for message in context.load_history() if isinstance(message, Message) and message.parts] current_message = get_message_text(input) @@ -72,20 +68,14 @@ async def streaming_agent_w_single_history_write_example(input: Message, context # This does not need to be the go-to approach in all cases, sometimes the partial outputs are of no value and one does not want them to be properly stored. # # Why not store each chunk? - # - Calling `context.store()`, PlatformContextStore saves every message as a distinct history item. - # - Storing per chunk would fragment one assistant turn into many partial messages. - # - A single aggregated write keeps replay, memory, and history semantics clean. - # aggregated_response = AgentMessage(text="\n".join(buffered_parts)) yield "Final result check:\n" + str(aggregated_response.text) - await context.store(data=aggregated_response) def run(): server.run( host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", "8000")), - context_store=PlatformContextStore(), ) diff --git a/examples/agent-integration/rag/conversation-rag-agent/src/conversation_rag_agent/agent.py b/examples/agent-integration/rag/conversation-rag-agent/src/conversation_rag_agent/agent.py index b5a7bdd8..9d1bc9fa 100644 --- a/examples/agent-integration/rag/conversation-rag-agent/src/conversation_rag_agent/agent.py +++ b/examples/agent-integration/rag/conversation-rag-agent/src/conversation_rag_agent/agent.py @@ -85,7 +85,6 @@ async def conversation_rag_agent_example( vector_store = await create_vector_store(embedding_client, embedding_model) # store vector store id in context for future messages data_part = DataPart(data={"vector_store_id": vector_store.id}) - await context.store(AgentMessage(parts=[data_part])) # Process files, add to vector 
store for file in files: diff --git a/skills/kagenti-adk-wrapper/SKILL.md b/skills/kagenti-adk-wrapper/SKILL.md index bded669b..4675d6d5 100644 --- a/skills/kagenti-adk-wrapper/SKILL.md +++ b/skills/kagenti-adk-wrapper/SKILL.md @@ -159,7 +159,7 @@ Read the agent's code and classify it. This determines the `interaction_mode` va This classification determines: -- How to use `context.store()` and `context.load_history()`: persist input/response by default for all agents; `context.load_history()` is required for multi-turn, and optional for single-turn (use only when prior context is intentionally part of behavior) +- How to use `context.load_history()`: history is auto-persisted by the A2A TaskStore; `context.load_history()` is required for multi-turn, and optional for single-turn (use only when prior context is intentionally part of behavior) - Whether to define an `initial_form` for structured inputs (single-turn with named parameters) --- @@ -277,8 +277,7 @@ When building and testing the wrapper, ensure you avoid these common pitfalls: - **Never use synchronous functions for the agent handler.** Agent functions must be `async def` generators using `yield`. - **Never hide platform wiring behind abstraction layers.** Keep `@server.agent(...)`, extension parameters, and integration contracts visible in the main entrypoint so behavior is auditable. - **Never treat runtime inspection as first source.** `kagenti_adk` and `a2a` details must come from provided docs first; use installed-environment inspection only as documented fallback, then validate imports at the end. -- **Never assume history is auto-saved.** Explicitly call `await context.store(input)` and `await context.store(response)`. -- **Never assume persistent history without `PlatformContextStore`.** Without it, context storage is in-memory and lost on restart. +- **History is auto-saved by the A2A framework.** Messages and artifacts are persisted in the A2A TaskStore automatically — do not manually store them. 
- **Never forget to filter history.** `context.load_history()` returns Messages and Artifacts. Filter with `isinstance(message, Message)`. - **Never store individual streaming chunks.** Accumulate the full response and store once. - **Never treat extension data as dictionaries.** Use dot notation (e.g., `config.api_key`, not `config.get("api_key")`). @@ -343,9 +342,7 @@ After wrapping, confirm: ### Context & History -- [ ] `input` and `response` stored via `context.store()` -- [ ] `context_store=PlatformContextStore()` present if context is persisted/read -- [ ] Multi-turn uses `context.load_history()`; single-turn only if intentionally needed +- [ ] Multi-turn uses `context.load_history()` to read conversation history from the A2A TaskStore ### Forms & Files diff --git a/skills/kagenti-adk-wrapper/references/wrapper-entrypoint.md b/skills/kagenti-adk-wrapper/references/wrapper-entrypoint.md index 218538ec..0820e65a 100644 --- a/skills/kagenti-adk-wrapper/references/wrapper-entrypoint.md +++ b/skills/kagenti-adk-wrapper/references/wrapper-entrypoint.md @@ -49,18 +49,15 @@ Based on the classification in Step 2, follow exactly ONE of these workflows: - [ ] Pass necessary inputs (from forms or text) to original agent logic - [ ] Yield trajectory for meaningful intermediate activity (same rule as all agents) - [ ] Yield the final response via `AgentMessage(text=result)` -- [ ] Persist both input and response via `context.store()` ``` ### If the agent is Multi-turn ``` -- [ ] Store input: Save incoming user message immediately with `await context.store(input)` - [ ] Load history: Retrieve past conversation via `[msg async for msg in context.load_history() if isinstance(msg, Message)]` - [ ] Execute agent: Pass the filtered history to the original agent logic - [ ] Yield trajectory for meaningful intermediate activity (same rule as all agents) - [ ] Yield response: Return final answering chunks with `yield AgentMessage(text=...)` -- [ ] Store response: Save the final 
response with `await context.store(response)` ``` ## Entrypoint @@ -68,6 +65,6 @@ Based on the classification in Step 2, follow exactly ONE of these workflows: Create a `run()` / `serve()` function protected by an `if __name__ == "__main__":` guard. This function should call `server.run()`: - The server should be configured to listen on a `host` and `port` from environment variables (e.g., `host=os.getenv("HOST", "127.0.0.1")`, `port=int(os.getenv("PORT", 8000))`). -- If the agent persists or reads context history, you must pass `context_store=PlatformContextStore()` to `server.run()`. +- Conversation history is automatically persisted in the A2A TaskStore. Use `context.load_history()` to read it. - **Remove all CLI argument parsing** (`argparse`). Map required CLI inputs to the wrapper parameters instead (e.g., from Forms, Settings, or Environment variables). - Only `auth_backend` if explicitly requested.