From 7dae81cb68cb90d2c1550e322292aa9899f28cd0 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 12:57:18 -0800 Subject: [PATCH 1/9] tmp --- llama_stack/apis/agents/agents.py | 6 +++--- .../agents/meta_reference/agent_instance.py | 19 +++++++++++++++++-- tests/client-sdk/agents/test_agents.py | 12 ++++++++++-- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 2f374b6388..382a67a576 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -344,15 +344,15 @@ async def create_agent_turn( ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... @webmethod( - route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/submit_tool_response_messages", + route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/tool_responses", method="POST", ) - async def submit_tool_response_messages( + async def submit_tool_responses( self, agent_id: str, session_id: str, turn_id: str, - tool_response_messages: List[ToolResponseMessage], + tool_responses: Dict[str, ToolResponseMessage], ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... @webmethod( diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 1c21df57f1..779dcf74d1 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -31,6 +31,7 @@ AgentTurnResponseStepStartPayload, AgentTurnResponseStreamChunk, AgentTurnResponseTurnCompletePayload, + AgentTurnResponseTurnPendingPayload, AgentTurnResponseTurnStartPayload, Attachment, Document, @@ -62,7 +63,11 @@ from llama_stack.apis.safety import Safety from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO -from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition +from llama_stack.models.llama.datatypes import ( + BuiltinTool, + ToolCall, + ToolParamDefinition, +) from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content from llama_stack.providers.utils.telemetry import tracing @@ -222,6 +227,15 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn ) await self.storage.add_turn_to_session(request.session_id, turn) + if output_message.tool_calls: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnPendingPayload( + turn=turn, + ) + ) + ) + else: chunk = AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseTurnCompletePayload( @@ -229,7 +243,8 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn ) ) ) - yield chunk + + yield chunk async def run( self, diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index e5380d357a..6b8caec252 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -19,8 +19,12 @@ from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig from llama_stack_client.types.tool_def_param import Parameter -from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig -from llama_stack.apis.agents.agents import ToolChoice +from llama_stack.apis.agents.agents import ( + AgentConfig as Server__AgentConfig, +) +from llama_stack.apis.agents.agents import ( + ToolChoice, +) class TestClientTool(ClientTool): @@ -314,6 +318,10 @@ def test_custom_tool(llama_stack_client, agent_config): ], session_id=session_id, ) + from rich.pretty import pprint + + for x in response: + pprint(x) logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) From 07c9222b6f77723405c7439a6850058e4bb5cc0d Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 14:54:45 -0800 Subject: [PATCH 2/9] debug --- .../agents/meta_reference/agent_instance.py | 45 ++++++++++--------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 779dcf74d1..743b77adf8 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -30,8 +30,8 @@ AgentTurnResponseStepProgressPayload, AgentTurnResponseStepStartPayload, AgentTurnResponseStreamChunk, + AgentTurnResponseTurnAwaitingInputPayload, AgentTurnResponseTurnCompletePayload, - AgentTurnResponseTurnPendingPayload, AgentTurnResponseTurnStartPayload, Attachment, Document, @@ -225,26 +225,27 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn completed_at=datetime.now(), steps=steps, ) - await self.storage.add_turn_to_session(request.session_id, turn) - if output_message.tool_calls: - chunk = AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnPendingPayload( - turn=turn, + if output_message.tool_calls: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnAwaitingInputPayload( + turn=turn, + ) ) ) - ) - else: - chunk = AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnCompletePayload( - turn=turn, + else: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + turn=turn, + ) ) ) - ) + # only add to storage if turn is complete + await self.storage.add_turn_to_session(request.session_id, turn) - yield chunk + yield chunk async def run( self, @@ -626,11 +627,7 @@ async def _run( input_messages = input_messages + [message] else: log.info(f"{str(message)}") - tool_call = message.tool_calls[0] - if tool_call.tool_name in client_tools: - yield message - return - + # 1. Start the tool execution step and progress step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -640,6 +637,8 @@ async def _run( ) ) ) + + tool_call = message.tool_calls[0] yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( @@ -654,6 +653,12 @@ async def _run( ) ) + # If tool is a client tool, yield CompletionMessage and return + if tool_call.tool_name in client_tools: + yield message + return + + # If tool is a builtin server tool, execute it tool_name = tool_call.tool_name if isinstance(tool_name, BuiltinTool): tool_name = tool_name.value From fb0d992f9972e0e9d63f2dbc7350069898bfda3a Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 15:48:55 -0800 Subject: [PATCH 3/9] temp debuug --- debug_custom_tool.py | 115 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 debug_custom_tool.py diff --git a/debug_custom_tool.py b/debug_custom_tool.py new file mode 100644 index 0000000000..26418072d9 --- /dev/null +++ b/debug_custom_tool.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from typing import Dict, List +from uuid import uuid4 + +from llama_stack_client import LlamaStackClient +from llama_stack_client.lib.agents.agent import Agent +from llama_stack_client.lib.agents.client_tool import ClientTool +from llama_stack_client.types import ToolResponseMessage +from llama_stack_client.types.shared.completion_message import CompletionMessage +from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig +from llama_stack_client.types.tool_def_param import Parameter +from rich.pretty import pprint + + +class TestClientTool(ClientTool): + """Tool to give boiling point of a liquid + Returns the correct value for polyjuice in Celcius and Fahrenheit + and returns -1 for other liquids + """ + + def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]: + assert len(messages) == 1, "Expected single message" + + message = messages[0] + + tool_call = message.tool_calls[0] + + try: + response = self.run_impl(**tool_call.arguments) + response_str = json.dumps(response, ensure_ascii=False) + except Exception as e: + response_str = f"Error when running tool: {e}" + + message = ToolResponseMessage( + role="tool", + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=response_str, + ) + return message + + def get_name(self) -> str: + return "get_boiling_point" + + def get_description(self) -> str: + return "Get the boiling point of imaginary liquids (eg. polyjuice)" + + def get_params_definition(self) -> Dict[str, Parameter]: + return { + "liquid_name": Parameter( + name="liquid_name", + parameter_type="string", + description="The name of the liquid", + required=True, + ), + "celcius": Parameter( + name="celcius", + parameter_type="boolean", + description="Whether to return the boiling point in Celcius", + required=False, + ), + } + + def run_impl(self, liquid_name: str, celcius: bool = True) -> int: + if liquid_name.lower() == "polyjuice": + if celcius: + return -100 + else: + return -212 + else: + return -1 + + +if __name__ == "__main__": + tool = TestClientTool() + agent_config = AgentConfig( + model="meta-llama/Llama-3.1-8B-Instruct", + instructions="You are a helpful assistant", + sampling_params={ + "strategy": { + "type": "top_p", + "temperature": 1.0, + "top_p": 0.9, + }, + }, + toolgroups=[], + input_shields=[], + output_shields=[], + tool_config=ToolConfig( + tool_choice="auto", + tool_prompt_format="json", + ), + client_tools=[tool.get_tool_definition()], + enable_session_persistence=False, + ) + client = LlamaStackClient(base_url="http://localhost:8321") + agent = Agent(client, agent_config, client_tools=(tool,)) + session_id = agent.create_session(f"test-session-{uuid4()}") + simple_hello = agent.create_turn( + messages=[ + { + "role": "user", + "content": "What is the boiling point of polyjuice in Celcius?", + } + ], + session_id=session_id, + ) + for chunk in simple_hello: + pprint(chunk) From 96c521ada631ebacbecf238acac552b2af46a46c Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 16:30:27 -0800 Subject: [PATCH 4/9] temp debug --- .../providers/inline/agents/meta_reference/agent_instance.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 743b77adf8..8da3f3a141 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -225,6 +225,7 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn completed_at=datetime.now(), steps=steps, ) + await self.storage.add_turn_to_session(request.session_id, turn) if output_message.tool_calls: chunk = AgentTurnResponseStreamChunk( @@ -242,8 +243,6 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn ) ) ) - # only add to storage if turn is complete - await self.storage.add_turn_to_session(request.session_id, turn) yield chunk From 5fbb159cf648c77df0633c6d0f42b580e25ef340 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 16:48:17 -0800 Subject: [PATCH 5/9] fix test --- debug_custom_tool.py | 115 ------------------------- tests/client-sdk/agents/test_agents.py | 4 - 2 files changed, 119 deletions(-) delete mode 100644 debug_custom_tool.py diff --git a/debug_custom_tool.py b/debug_custom_tool.py deleted file mode 100644 index 26418072d9..0000000000 --- a/debug_custom_tool.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import json -from typing import Dict, List -from uuid import uuid4 - -from llama_stack_client import LlamaStackClient -from llama_stack_client.lib.agents.agent import Agent -from llama_stack_client.lib.agents.client_tool import ClientTool -from llama_stack_client.types import ToolResponseMessage -from llama_stack_client.types.shared.completion_message import CompletionMessage -from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig -from llama_stack_client.types.tool_def_param import Parameter -from rich.pretty import pprint - - -class TestClientTool(ClientTool): - """Tool to give boiling point of a liquid - Returns the correct value for polyjuice in Celcius and Fahrenheit - and returns -1 for other liquids - """ - - def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]: - assert len(messages) == 1, "Expected single message" - - message = messages[0] - - tool_call = message.tool_calls[0] - - try: - response = self.run_impl(**tool_call.arguments) - response_str = json.dumps(response, ensure_ascii=False) - except Exception as e: - response_str = f"Error when running tool: {e}" - - message = ToolResponseMessage( - role="tool", - call_id=tool_call.call_id, - tool_name=tool_call.tool_name, - content=response_str, - ) - return message - - def get_name(self) -> str: - return "get_boiling_point" - - def get_description(self) -> str: - return "Get the boiling point of imaginary liquids (eg. polyjuice)" - - def get_params_definition(self) -> Dict[str, Parameter]: - return { - "liquid_name": Parameter( - name="liquid_name", - parameter_type="string", - description="The name of the liquid", - required=True, - ), - "celcius": Parameter( - name="celcius", - parameter_type="boolean", - description="Whether to return the boiling point in Celcius", - required=False, - ), - } - - def run_impl(self, liquid_name: str, celcius: bool = True) -> int: - if liquid_name.lower() == "polyjuice": - if celcius: - return -100 - else: - return -212 - else: - return -1 - - -if __name__ == "__main__": - tool = TestClientTool() - agent_config = AgentConfig( - model="meta-llama/Llama-3.1-8B-Instruct", - instructions="You are a helpful assistant", - sampling_params={ - "strategy": { - "type": "top_p", - "temperature": 1.0, - "top_p": 0.9, - }, - }, - toolgroups=[], - input_shields=[], - output_shields=[], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format="json", - ), - client_tools=[tool.get_tool_definition()], - enable_session_persistence=False, - ) - client = LlamaStackClient(base_url="http://localhost:8321") - agent = Agent(client, agent_config, client_tools=(tool,)) - session_id = agent.create_session(f"test-session-{uuid4()}") - simple_hello = agent.create_turn( - messages=[ - { - "role": "user", - "content": "What is the boiling point of polyjuice in Celcius?", - } - ], - session_id=session_id, - ) - for chunk in simple_hello: - pprint(chunk) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 6b8caec252..781095d2b4 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -318,10 +318,6 @@ def test_custom_tool(llama_stack_client, agent_config): ], session_id=session_id, ) - from rich.pretty import pprint - - for x in response: - pprint(x) logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) From e2bfd165d2bf66c2e0279f9934eef4e980c3febb Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 21:49:06 -0800 Subject: [PATCH 6/9] add flag allow_turn_resume --- docs/_static/llama-stack-spec.html | 3 +++ docs/_static/llama-stack-spec.yaml | 2 ++ llama_stack/apis/agents/agents.py | 4 ++++ .../providers/inline/agents/meta_reference/agent_instance.py | 2 +- llama_stack/providers/inline/agents/meta_reference/agents.py | 2 ++ 5 files changed, 12 insertions(+), 1 deletion(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 0a5d93d802..516a0174dc 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -4287,6 +4287,9 @@ }, "tool_config": { "$ref": "#/components/schemas/ToolConfig" + }, + "allow_turn_resume": { + "type": "boolean" } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index c05eef95eb..c7f4399825 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -2779,6 +2779,8 @@ components: $ref: '#/components/schemas/AgentTool' tool_config: $ref: '#/components/schemas/ToolConfig' + allow_turn_resume: + type: boolean additionalProperties: false required: - messages diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 8fde864e43..eb1cdde903 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -296,6 +296,9 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn): stream: Optional[bool] = False tool_config: Optional[ToolConfig] = None + # TODO (xiyan): used for backward compatibility, update for 0.1.5 + allow_turn_resume: Optional[bool] = False + @json_schema_type class AgentTurnResumeRequest(BaseModel): @@ -352,6 +355,7 @@ async def create_agent_turn( documents: Optional[List[Document]] = None, toolgroups: Optional[List[AgentToolGroup]] = None, tool_config: Optional[ToolConfig] = None, + allow_turn_resume: Optional[bool] = False, ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... @webmethod( diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 8da3f3a141..77c9c86296 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -227,7 +227,7 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn ) await self.storage.add_turn_to_session(request.session_id, turn) - if output_message.tool_calls: + if output_message.tool_calls and request.allow_turn_resume: chunk = AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseTurnAwaitingInputPayload( diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index acacbdfdf0..8921d56285 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -146,6 +146,7 @@ async def create_agent_turn( documents: Optional[List[Document]] = None, stream: Optional[bool] = False, tool_config: Optional[ToolConfig] = None, + allow_turn_resume: Optional[bool] = False, ) -> AsyncGenerator: request = AgentTurnCreateRequest( agent_id=agent_id, @@ -155,6 +156,7 @@ async def create_agent_turn( toolgroups=toolgroups, documents=documents, tool_config=tool_config, + allow_turn_resume=allow_turn_resume, ) if stream: return self._create_agent_turn_streaming(request) From db764e7ed61506ec00542456fe86fc7981546d3e Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 22:55:05 -0800 Subject: [PATCH 7/9] add doc --- llama_stack/apis/agents/agents.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index eb1cdde903..c904fdbef3 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -296,7 +296,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn): stream: Optional[bool] = False tool_config: Optional[ToolConfig] = None - # TODO (xiyan): used for backward compatibility, update for 0.1.5 + # TODO (xiyan): temporary flag, will remove for 0.1.5 allow_turn_resume: Optional[bool] = False @@ -369,7 +369,19 @@ async def resume_agent_turn( turn_id: str, tool_responses: List[ToolResponseMessage], stream: Optional[bool] = False, - ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... + ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: + """Resume an agent turn with executed tool call responses. + + When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready. + + :param agent_id: The ID of the agent to resume. + :param session_id: The ID of the session to resume. + :param turn_id: The ID of the turn to resume. + :param tool_responses: The tool call responses to resume the turn with. + :param stream: Whether to stream the response. + :returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects. + """ + ... @webmethod( route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", From 4f2427c6c8f7d440e22dca05ba8e16e5d6d7a04e Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 22:55:22 -0800 Subject: [PATCH 8/9] add doc --- docs/_static/llama-stack-spec.html | 13 +++++++++---- docs/_static/llama-stack-spec.yaml | 17 ++++++++++++++--- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 516a0174dc..8e8b8804f8 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -2319,7 +2319,7 @@ "post": { "responses": { "200": { - "description": "A single turn in an interaction with an Agentic System. **OR** streamed agent turn completion response.", + "description": "A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.", "content": { "application/json": { "schema": { @@ -2337,11 +2337,12 @@ "tags": [ "Agents" ], - "description": "", + "description": "Resume an agent turn with executed tool call responses.\nWhen a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.", "parameters": [ { "name": "agent_id", "in": "path", + "description": "The ID of the agent to resume.", "required": true, "schema": { "type": "string" @@ -2350,6 +2351,7 @@ { "name": "session_id", "in": "path", + "description": "The ID of the session to resume.", "required": true, "schema": { "type": "string" @@ -2358,6 +2360,7 @@ { "name": "turn_id", "in": "path", + "description": "The ID of the turn to resume.", "required": true, "schema": { "type": "string" @@ -8109,10 +8112,12 @@ "type": "array", "items": { "$ref": "#/components/schemas/ToolResponseMessage" - } + }, + "description": "The tool call responses to resume the turn with." }, "stream": { - "type": "boolean" + "type": "boolean", + "description": "Whether to stream the response." } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index c7f4399825..c14ae3d3ad 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -1406,8 +1406,8 @@ paths: responses: '200': description: >- - A single turn in an interaction with an Agentic System. **OR** streamed - agent turn completion response. + A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk + objects. content: application/json: schema: @@ -1417,20 +1417,28 @@ paths: $ref: '#/components/schemas/AgentTurnResponseStreamChunk' tags: - Agents - description: '' + description: >- + Resume an agent turn with executed tool call responses. + + When a Turn has the status `awaiting_input` due to pending input from client + side tool calls, this endpoint can be used to submit the outputs from the + tool calls once they are ready. parameters: - name: agent_id in: path + description: The ID of the agent to resume. required: true schema: type: string - name: session_id in: path + description: The ID of the session to resume. required: true schema: type: string - name: turn_id in: path + description: The ID of the turn to resume. required: true schema: type: string @@ -5244,8 +5252,11 @@ components: type: array items: $ref: '#/components/schemas/ToolResponseMessage' + description: >- + The tool call responses to resume the turn with. stream: type: boolean + description: Whether to stream the response. additionalProperties: false required: - tool_responses From 4dc7f05a2d0356a6332f86ce5581a6079690d623 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Fri, 21 Feb 2025 11:43:48 -0800 Subject: [PATCH 9/9] feat(3/n): agent resume_turn (#1194) # What does this PR do? - https://github.com/meta-llama/llama-stack/pull/1178 - https://github.com/meta-llama/llama-stack/pull/1187 - https://github.com/meta-llama/llama-stack/pull/1194 **client changes** - https://github.com/meta-llama/llama-stack-client-python/pull/157 - https://github.com/meta-llama/llama-stack-client-python/pull/158 ## Test Plan ``` LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/agents/test_agents.py --inference-model meta-llama/Llama-3.1-8B-Instruct ``` ``` LLAMA_STACK_CONFIG=fireworks pytest -v tests/client-sdk/agents/test_agents.py --inference-model meta-llama/Llama-3.1-8B-Instruct ``` image **llama-stack-apps** ``` python -m examples.agents.react_agent localhost 8321 ``` - Test with script: https://gist.github.com/yanxi0830/f2e407527f468998a700cd29fd271b15 **Output Before**: we have 2 `turn_id` with 2 turns - https://gist.github.com/yanxi0830/9fbd7a80fcddc784a28c59d4a9c1d943 **Output After**: we have 1 `turn_id`, the final turn have all 3 steps - https://gist.github.com/yanxi0830/17754d56d08ccbeaec419b693137500c image **Telemetry** image [//]: # (## Documentation) --- .../agents/meta_reference/agent_instance.py | 144 ++++++++++++++++-- .../inline/agents/meta_reference/agents.py | 21 ++- .../agents/meta_reference/persistence.py | 14 +- 3 files changed, 168 insertions(+), 11 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 77c9c86296..edd253356f 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -33,6 +33,7 @@ AgentTurnResponseTurnAwaitingInputPayload, AgentTurnResponseTurnCompletePayload, AgentTurnResponseTurnStartPayload, + AgentTurnResumeRequest, Attachment, Document, InferenceStep, @@ -156,6 +157,15 @@ def turn_to_messages(self, turn: Turn) -> List[Message]: async def create_session(self, name: str) -> str: return await self.storage.create_session(name) + async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]: + messages = [] + if self.agent_config.instructions != "": + messages.append(SystemMessage(content=self.agent_config.instructions)) + + for turn in turns: + messages.extend(self.turn_to_messages(turn)) + return messages + async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator: with tracing.span("create_and_execute_turn") as span: span.set_attribute("session_id", request.session_id) @@ -168,14 +178,7 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn raise ValueError(f"Session {request.session_id} not found") turns = await self.storage.get_session_turns(request.session_id) - - messages = [] - if self.agent_config.instructions != "": - messages.append(SystemMessage(content=self.agent_config.instructions)) - - for i, turn in enumerate(turns): - messages.extend(self.turn_to_messages(turn)) - + messages = await self.get_messages_from_turns(turns) messages.extend(request.messages) turn_id = str(uuid.uuid4()) @@ -246,6 +249,119 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn yield chunk + async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: + with tracing.span("resume_turn") as span: + span.set_attribute("agent_id", self.agent_id) + span.set_attribute("session_id", request.session_id) + span.set_attribute("turn_id", request.turn_id) + span.set_attribute("request", request.model_dump_json()) + assert request.stream is True, "Non-streaming not supported" + + session_info = await self.storage.get_session_info(request.session_id) + if session_info is None: + raise ValueError(f"Session {request.session_id} not found") + + turns = await self.storage.get_session_turns(request.session_id) + messages = await self.get_messages_from_turns(turns) + messages.extend(request.tool_responses) + + last_turn_messages = [ + x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) + ] + + # get the steps from the turn id + steps = [] + if len(turns) > 0: + steps = turns[-1].steps + + # mark tool execution step as complete + # if there's no tool execution in progress step (due to storage, or tool call parsing on client), + # we'll create a new tool execution step with current time + in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step( + request.session_id, request.turn_id + ) + now = datetime.now() + tool_execution_step = ToolExecutionStep( + step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), + turn_id=request.turn_id, + tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), + tool_responses=[ + ToolResponse( + call_id=x.call_id, + tool_name=x.tool_name, + content=x.content, + ) + for x in request.tool_responses + ], + completed_at=now, + started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), + ) + steps.append(tool_execution_step) + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepCompletePayload( + step_type=StepType.tool_execution.value, + step_id=tool_execution_step.step_id, + step_details=tool_execution_step, + ) + ) + ) + + output_message = None + async for chunk in self.run( + session_id=request.session_id, + turn_id=request.turn_id, + input_messages=messages, + sampling_params=self.agent_config.sampling_params, + stream=request.stream, + ): + if isinstance(chunk, CompletionMessage): + output_message = chunk + continue + + assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" + event = chunk.event + if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: + steps.append(event.payload.step_details) + + yield chunk + + assert output_message is not None + + last_turn_start_time = datetime.now() + if len(turns) > 0: + last_turn_start_time = turns[-1].started_at + + turn = Turn( + turn_id=request.turn_id, + session_id=request.session_id, + input_messages=last_turn_messages, + output_message=output_message, + started_at=last_turn_start_time, + completed_at=datetime.now(), + steps=steps, + ) + await self.storage.add_turn_to_session(request.session_id, turn) + + if output_message.tool_calls: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnAwaitingInputPayload( + turn=turn, + ) + ) + ) + else: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + turn=turn, + ) + ) + ) + + yield chunk + async def run( self, session_id: str, @@ -636,7 +752,6 @@ async def _run( ) ) ) - tool_call = message.tool_calls[0] yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -654,6 +769,17 @@ async def _run( # If tool is a client tool, yield CompletionMessage and return if tool_call.tool_name in client_tools: + await self.storage.set_in_progress_tool_call_step( + session_id, + turn_id, + ToolExecutionStep( + step_id=step_id, + turn_id=turn_id, + tool_calls=[tool_call], + tool_responses=[], + started_at=datetime.now(), + ), + ) yield message return diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 8921d56285..8a4d912382 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -21,6 +21,7 @@ AgentStepResponse, AgentToolGroup, AgentTurnCreateRequest, + AgentTurnResumeRequest, Document, Session, Turn, @@ -179,7 +180,25 @@ async def resume_agent_turn( tool_responses: List[ToolResponseMessage], stream: Optional[bool] = False, ) -> AsyncGenerator: - pass + request = AgentTurnResumeRequest( + agent_id=agent_id, + session_id=session_id, + turn_id=turn_id, + tool_responses=tool_responses, + stream=stream, + ) + if stream: + return self._continue_agent_turn_streaming(request) + else: + raise NotImplementedError("Non-streaming agent turns not yet implemented") + + async def _continue_agent_turn_streaming( + self, + request: AgentTurnResumeRequest, + ) -> AsyncGenerator: + agent = await self.get_agent(request.agent_id) + async for event in agent.resume_turn(request): + yield event async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn: turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}") diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 4b8ad6d4ad..3c3866873b 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -12,7 +12,7 @@ from pydantic import BaseModel -from llama_stack.apis.agents import Turn +from llama_stack.apis.agents import ToolExecutionStep, Turn from llama_stack.providers.utils.kvstore import KVStore log = logging.getLogger(__name__) @@ -84,3 +84,15 @@ async def get_session_turns(self, session_id: str) -> List[Turn]: continue turns.sort(key=lambda x: (x.completed_at or datetime.min)) return turns + + async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep): + await self.kvstore.set( + key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", + value=step.model_dump_json(), + ) + + async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]: + value = await self.kvstore.get( + key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", + ) + return ToolExecutionStep(**json.loads(value)) if value else None