From e8dc718281c997c73e1e7aa6b5d0a3a2154f0ae6 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 3 Feb 2025 15:59:49 -0800 Subject: [PATCH 01/12] react agent --- examples/agents/react_agent.py | 252 +++++++++++++++++++++++++++++++++ 1 file changed, 252 insertions(+) create mode 100644 examples/agents/react_agent.py diff --git a/examples/agents/react_agent.py b/examples/agents/react_agent.py new file mode 100644 index 00000000..48fd8926 --- /dev/null +++ b/examples/agents/react_agent.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +# import os +import uuid +from typing import Dict, List, Union + +import fire + +from llama_stack_client import LlamaStackClient +from llama_stack_client.lib.agents.agent import Agent +from llama_stack_client.lib.agents.client_tool import ClientTool + +# from llama_stack_client.lib.agents.event_logger import EventLogger +from llama_stack_client.types.agent_create_params import AgentConfig +from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage +from llama_stack_client.types.shared.user_message import UserMessage +from llama_stack_client.types.tool_def_param import Parameter +from rich.pretty import pprint + +REACT_PROMPT = """ +You are an expert assistant who can solve any task using tool calls. You will be given a task to solve as best you can. +To do so, you have been given access to the following tools: <> +The way you use the tools is by specifying a json blob, ending with ''. +Specifically, this json should have an `action` key (name of the tool to use) and an `action_input` key (input to the tool). + +The $ACTION_JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. It should be formatted in json. Do not try to escape special characters. Here is the template of a valid $ACTION_JSON_BLOB: +{ + "action": $TOOL_NAME, + "action_input": $INPUT +} + +Make sure to have the $INPUT as a dictionary in the right format for the tool you are using, and do not put variable names as input if you can find the right values. + +You should ALWAYS use the following format: + +Thought: you should always think about one action to take. Then use the action as follows: +Action: +$ACTION_JSON_BLOB +Observation: the result of the action +... (this Thought/Action/Observation can repeat N times, you should take several steps when needed. The $ACTION_JSON_BLOB must only use a SINGLE action at a time.) + +You can use the result of the previous action as input for the next action. +The observation will always be a string: it can represent a file, like "image_1.jpg". +Then you can use it as input for the next action. You can do it for instance as follows: + +Observation: "image_1.jpg" + +Thought: I need to transform the image that I received in the previous observation to make it green. +Action: +{ + "action": "image_transformer", + "action_input": {"image": "image_1.jpg"} +} + +To provide the final answer to the task, use an action blob with "action": "final_answer" tool. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this: +Action: +{ + "action": "final_answer", + "action_input": {"answer": "insert your final answer here"} +} + + +Here are a few examples using notional tools: +--- +Task: "Generate an image of the oldest person in this document." + +Thought: I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer. +Action: +{ + "action": "document_qa", + "action_input": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"} +} + +Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland." + + +Thought: I will now generate an image showcasing the oldest person. +Action: +{ + "action": "image_generator", + "action_input": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."} +} +Observation: "image.png" + +Thought: I will now return the generated image. +Action: +{ + "action": "final_answer", + "action_input": "image.png" +} + +--- +Task: "What is the result of the following operation: 5 + 3 + 1294.678?" + +Thought: I will use python code evaluator to compute the result of the operation and then return the final answer using the `final_answer` tool +Action: +{ + "action": "python_interpreter", + "action_input": {"code": "5 + 3 + 1294.678"} +} +Observation: 1302.678 + +Thought: Now that I know the result, I will now return it. +Action: +{ + "action": "final_answer", + "action_input": "1302.678" +} + +--- +Task: "Which city has the highest population , Guangzhou or Shanghai?" + +Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities. +Action: +{ + "action": "search", + "action_input": "Population Guangzhou" +} +Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.'] + + +Thought: Now let's get the population of Shanghai using the tool 'search'. +Action: +{ + "action": "search", + "action_input": "Population Shanghai" +} +Observation: '26 million (2019)' + +Thought: Now I know that Shanghai has a larger population. Let's return the result. +Action: +{ + "action": "final_answer", + "action_input": "Shanghai" +} + + +Above example were using notional tools that might not exist for you. You only have access to these tools: +<> + +Here are the rules you should always follow to solve your task: +1. ALWAYS provide a single 'Thought:' sequence, and a single 'Action:' sequence that ends with , else you will fail. +2. Always use the right arguments for the tools. Never use variable names in the 'action_input' field, use the value instead. +3. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself. +4. Never re-do a tool call that you previously did with the exact same parameters. +5. Observations will be provided to you, no need to generate them + +Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000. +""" + + +class MetaExternalSearchTool(ClientTool): + + def get_name(self) -> str: + return "get_external_meta_data" + + def get_description(self) -> str: + return """ +Search the web for the given query about Meta. Get information Meta available on the public internet +""" + + def get_params_definition(self) -> Dict[str, Parameter]: + return { + "query": Parameter( + name="query", + parameter_type="str", + description="The query to use for querying the internet", + required=True, + ) + } + + def run( + self, messages: List[Union[UserMessage, ToolResponseMessage]] + ) -> List[Union[UserMessage, ToolResponseMessage]]: + print("run_impl for MetaExternalSearchTool called") + dummy_response = """ + torchtune is a PyTorch library for easily authoring, finetuning and experimenting with LLMs. + + torchtune provides: + + PyTorch implementations of popular LLMs from Llama, Gemma, Mistral, Phi, and Qwen model families + Hackable training recipes for full finetuning, LoRA, QLoRA, DPO, PPO, QAT, knowledge distillation, and more + Out-of-the-box memory efficiency, performance improvements, and scaling with the latest PyTorch APIs + YAML configs for easily configuring training, evaluation, quantization or inference recipes + Built-in support for many popular dataset formats and prompt templates + """ + return [ + ToolResponseMessage( + call_id="random-id", + tool_name=self.get_name(), + content=dummy_response, + role="tool", + ) + ] + + +def main(): + client = LlamaStackClient( + base_url="http://localhost:8321", + ) + + model = "meta-llama/Llama-3.3-70B-Instruct" + + client_tools = [ + MetaExternalSearchTool(), + ] + + tool_names = ", ".join([tool.get_name() for tool in client_tools]) + tool_descriptions = "\n".join( + [f"- {tool.get_name()}: {tool.get_description()}" for tool in client_tools] + ) + instruction = REACT_PROMPT.replace("<>", tool_names).replace( + "<>", tool_descriptions + ) + + agent_config = AgentConfig( + model=model, + instructions=instruction, + sampling_params={ + "strategy": {"type": "top_p", "temperature": 1.0, "top_p": 0.9}, + }, + client_tools=[ + client_tool.get_tool_definition() for client_tool in client_tools + ], + tool_choice="auto", + tool_prompt_format="python_list", + input_shields=[], + output_shields=[], + enable_session_persistence=False, + ) + agent = Agent(client, agent_config, client_tools) + + session_id = agent.create_session(f"ttest-session-{uuid.uuid4().hex}") + + response = agent.create_turn( + messages=[ + {"role": "user", "content": "What model families does torchtune support?"} + ], + session_id=session_id, + stream=False, + ) + pprint(response) + + # for chunk in response: + # pprint(chunk) + + +if __name__ == "__main__": + fire.Fire(main) From de942446efa5b88d7de442b6ceb630575980f5eb Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 3 Feb 2025 16:01:26 -0800 Subject: [PATCH 02/12] tmp --- examples/agents/react_agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/agents/react_agent.py b/examples/agents/react_agent.py index 48fd8926..4cf634fb 100644 --- a/examples/agents/react_agent.py +++ b/examples/agents/react_agent.py @@ -159,8 +159,8 @@ def get_name(self) -> str: def get_description(self) -> str: return """ -Search the web for the given query about Meta. Get information Meta available on the public internet -""" + Search the web for the given query about Meta. Get information Meta available on the public internet + """ def get_params_definition(self) -> Dict[str, Parameter]: return { From 1a493a948d57387f3b78360164c68468d61ed1f6 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 4 Feb 2025 11:19:53 -0800 Subject: [PATCH 03/12] example --- examples/agents/react_agent.py | 68 ++++++++++++++++++++++++++++++++-- 1 file changed, 65 insertions(+), 3 deletions(-) diff --git a/examples/agents/react_agent.py b/examples/agents/react_agent.py index 4cf634fb..c0e890e9 100644 --- a/examples/agents/react_agent.py +++ b/examples/agents/react_agent.py @@ -4,17 +4,22 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. # import os +import json +import re import uuid -from typing import Dict, List, Union +from typing import Any, Dict, List, Optional, Tuple, Union import fire from llama_stack_client import LlamaStackClient from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.client_tool import ClientTool +from llama_stack_client.lib.agents.output_parser import OutputParser # from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.types.agent_create_params import AgentConfig +from llama_stack_client.types.agents.turn import CompletionMessage +from llama_stack_client.types.shared.tool_call import ToolCall from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage from llama_stack_client.types.shared.user_message import UserMessage from llama_stack_client.types.tool_def_param import Parameter @@ -152,6 +157,63 @@ """ +class ReActOutputParser(OutputParser): + def maybe_extract_action(self, text: str) -> Optional[Tuple[str, Dict[str, Any]]]: + """ + Extract action name and parameters from the text format: + + Thought: + + Action: + { + "action": , + "action_input": + } + + Args: + text (str): Input text containing the action block + + Returns: + Tuple[str, Dict[str, Any]]: Tuple of (action_name, action_parameters) + + Raises: + ValueError: If the action block cannot be parsed or is missing required fields + """ + try: + # Find the action block using regex + action_pattern = r'Action:\s*{\s*"action":\s*"([^"]+)",\s*"action_input":\s*({[^}]+})\s*}' + match = re.search(action_pattern, text, re.DOTALL) + + if not match: + raise ValueError("Could not find valid action block in text") + + action_name = match.group(1) + action_params = json.loads(match.group(2)) + + return action_name, action_params + except (ValueError, json.JSONDecodeError) as e: + print(f"Error parsing action: {e}") + return None + + def parse(self, output_message: CompletionMessage) -> CompletionMessage: + action = self._maybe_extract_action(output_message.content) + if action is None: + return output_message + + action_name, action_params = action + call_id = str(uuid.uuid4()) + return CompletionMessage( + content=output_message.content, + tool_calls=[ + ToolCall( + call_id=call_id, + tool_name=action_name, + arguments=action_params, + ) + ], + ) + + class MetaExternalSearchTool(ClientTool): def get_name(self) -> str: @@ -175,7 +237,7 @@ def get_params_definition(self) -> Dict[str, Parameter]: def run( self, messages: List[Union[UserMessage, ToolResponseMessage]] ) -> List[Union[UserMessage, ToolResponseMessage]]: - print("run_impl for MetaExternalSearchTool called") + print("run for MetaExternalSearchTool called") dummy_response = """ torchtune is a PyTorch library for easily authoring, finetuning and experimenting with LLMs. @@ -231,7 +293,7 @@ def main(): output_shields=[], enable_session_persistence=False, ) - agent = Agent(client, agent_config, client_tools) + agent = Agent(client, agent_config, client_tools, output_parser=ReActOutputParser()) session_id = agent.create_session(f"ttest-session-{uuid.uuid4().hex}") From e28b16ee3af75e90f795b6deae65bc40828875ae Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 4 Feb 2025 11:33:25 -0800 Subject: [PATCH 04/12] update example --- examples/agents/react_agent.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/examples/agents/react_agent.py b/examples/agents/react_agent.py index c0e890e9..fc9d4664 100644 --- a/examples/agents/react_agent.py +++ b/examples/agents/react_agent.py @@ -158,7 +158,7 @@ class ReActOutputParser(OutputParser): - def maybe_extract_action(self, text: str) -> Optional[Tuple[str, Dict[str, Any]]]: + def _maybe_extract_action(self, text: str) -> Optional[Tuple[str, Dict[str, Any]]]: """ Extract action name and parameters from the text format: @@ -196,22 +196,21 @@ def maybe_extract_action(self, text: str) -> Optional[Tuple[str, Dict[str, Any]] return None def parse(self, output_message: CompletionMessage) -> CompletionMessage: - action = self._maybe_extract_action(output_message.content) - if action is None: + text = str(output_message.content) + action = self._maybe_extract_action(text) + if action is None or action[0] == "final_answer": return output_message action_name, action_params = action call_id = str(uuid.uuid4()) - return CompletionMessage( - content=output_message.content, - tool_calls=[ - ToolCall( - call_id=call_id, - tool_name=action_name, - arguments=action_params, - ) - ], - ) + output_message.tool_calls = [ + ToolCall( + call_id=call_id, + tool_name=action_name, + arguments=action_params, + ) + ] + return output_message class MetaExternalSearchTool(ClientTool): @@ -306,9 +305,6 @@ def main(): ) pprint(response) - # for chunk in response: - # pprint(chunk) - if __name__ == "__main__": fire.Fire(main) From a8ac2f539687cc0014c4850cb898ccfc38fe5620 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 4 Feb 2025 19:31:12 -0800 Subject: [PATCH 05/12] refactor --- examples/agents/react_agent.py | 259 +++++++++++++++++---------------- 1 file changed, 130 insertions(+), 129 deletions(-) diff --git a/examples/agents/react_agent.py b/examples/agents/react_agent.py index fc9d4664..48dcc019 100644 --- a/examples/agents/react_agent.py +++ b/examples/agents/react_agent.py @@ -5,9 +5,8 @@ # the root directory of this source tree. # import os import json -import re import uuid -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union import fire @@ -25,130 +24,140 @@ from llama_stack_client.types.tool_def_param import Parameter from rich.pretty import pprint -REACT_PROMPT = """ +REACT_JSON_PROMPT = """ You are an expert assistant who can solve any task using tool calls. You will be given a task to solve as best you can. To do so, you have been given access to the following tools: <> -The way you use the tools is by specifying a json blob, ending with ''. -Specifically, this json should have an `action` key (name of the tool to use) and an `action_input` key (input to the tool). -The $ACTION_JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. It should be formatted in json. Do not try to escape special characters. Here is the template of a valid $ACTION_JSON_BLOB: +You must always respond in the following JSON format: { - "action": $TOOL_NAME, - "action_input": $INPUT -} + "thought": $THOUGHT_PROCESS, + "action": { + "tool_name": $TOOL_NAME, + "tool_params": $TOOL_PARAMS + }, + "observation": $OBSERVATION, + "answer": $ANSWER +} + +Specifically, this json should have a `thought` key, a `action` key, and an `observation` key. -Make sure to have the $INPUT as a dictionary in the right format for the tool you are using, and do not put variable names as input if you can find the right values. +The `action` key should specify the $TOOL_NAME the name of the tool to use and the `tool_params` key should specify the parameters key as input to the tool. -You should ALWAYS use the following format: +Make sure to have the $TOOL_PARAMS as a dictionary in the right format for the tool you are using, and do not put variable names as input if you can find the right values. -Thought: you should always think about one action to take. Then use the action as follows: -Action: -$ACTION_JSON_BLOB -Observation: the result of the action -... (this Thought/Action/Observation can repeat N times, you should take several steps when needed. The $ACTION_JSON_BLOB must only use a SINGLE action at a time.) +You should always think about one action to take, and have the `thought` key contain your thought process about this action. +The `observation` key should contain the result of the action +... (this Thought/Action/Observation can repeat N times, you should take several steps when needed. The action key must only use a SINGLE tool at a time.) You can use the result of the previous action as input for the next action. The observation will always be a string: it can represent a file, like "image_1.jpg". Then you can use it as input for the next action. You can do it for instance as follows: -Observation: "image_1.jpg" - -Thought: I need to transform the image that I received in the previous observation to make it green. -Action: { - "action": "image_transformer", - "action_input": {"image": "image_1.jpg"} -} + "observation": "image_1.jpg", + "thought": "I need to transform the image that I received in the previous observation to make it green.", + "action": { + "tool_name": "image_transformer", + "tool_params": {"image": "image_1.jpg"} + }, + "answer": null +} -To provide the final answer to the task, use an action blob with "action": "final_answer" tool. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this: -Action: -{ - "action": "final_answer", - "action_input": {"answer": "insert your final answer here"} -} +To provide the final answer to the task, use the `answer` key. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this: +{ + "observation": "your observation", + "thought": "you thought process", + "action": null, + "answer": "insert your final answer here" +} Here are a few examples using notional tools: --- Task: "Generate an image of the oldest person in this document." -Thought: I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer. -Action: { - "action": "document_qa", - "action_input": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"} -} - -Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland." - + "thought": "I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.", + "action": { + "tool_name": "document_qa", + "tool_params": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"} + }, + "observation": "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland.", + "answer": null +} -Thought: I will now generate an image showcasing the oldest person. -Action: { - "action": "image_generator", - "action_input": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."} -} -Observation: "image.png" + "thought": "I will now generate an image showcasing the oldest person.", + "action": { + "tool_name": "image_generator", + "tool_params": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."} + }, + "observation": "image.png", + "answer": null +} -Thought: I will now return the generated image. -Action: { - "action": "final_answer", - "action_input": "image.png" -} + "thought": "I will now return the generated image.", + "action": null, + "answer": "image.png" +} --- Task: "What is the result of the following operation: 5 + 3 + 1294.678?" -Thought: I will use python code evaluator to compute the result of the operation and then return the final answer using the `final_answer` tool -Action: { - "action": "python_interpreter", - "action_input": {"code": "5 + 3 + 1294.678"} -} -Observation: 1302.678 + "thought": "I will use python code evaluator to compute the result of the operation and then return the final answer using the `final_answer` tool", + "action": { + "tool_name": "python_interpreter", + "tool_params": {"code": "5 + 3 + 1294.678"} + }, + "observation": 1302.678, + "answer": null +} -Thought: Now that I know the result, I will now return it. -Action: { - "action": "final_answer", - "action_input": "1302.678" -} + "thought": "Now that I know the result, I will now return it.", + "action": null, + "observation": null, + "answer": 1302.678 +} --- Task: "Which city has the highest population , Guangzhou or Shanghai?" -Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities. -Action: { - "action": "search", - "action_input": "Population Guangzhou" -} -Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.'] - + "thought": "I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.", + "action": { + "tool_name": "search", + "tool_params": {"query": "Population Guangzhou"} + }, + "observation": ['Guangzhou has a population of 15 million inhabitants as of 2021.'], + "answer": null +} -Thought: Now let's get the population of Shanghai using the tool 'search'. -Action: { - "action": "search", - "action_input": "Population Shanghai" + "thought": "Now let's get the population of Shanghai using the tool 'search'.", + "action": { + "tool_name": "search", + "tool_params": {"query": "Population Shanghai"} + }, + "observation": "26 million (2019)", + "answer": null } -Observation: '26 million (2019)' -Thought: Now I know that Shanghai has a larger population. Let's return the result. -Action: { - "action": "final_answer", - "action_input": "Shanghai" -} - + "thought": "Now I know that Shanghai has a larger population. Let's return the result.", + "action": null, + "observation": null, + "answer": "Shanghai" +} Above example were using notional tools that might not exist for you. You only have access to these tools: <> Here are the rules you should always follow to solve your task: -1. ALWAYS provide a single 'Thought:' sequence, and a single 'Action:' sequence that ends with , else you will fail. -2. Always use the right arguments for the tools. Never use variable names in the 'action_input' field, use the value instead. +1. ALWAYS answer in the JSON format with keys "observation", "thought", "action", "answer", else you will fail. +2. Always use the right arguments for the tools. Never use variable names in the 'tool_params' field, use the value instead. 3. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself. 4. Never re-do a tool call that you previously did with the exact same parameters. 5. Observations will be provided to you, no need to generate them @@ -156,71 +165,55 @@ Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000. """ +from pydantic import BaseModel -class ReActOutputParser(OutputParser): - def _maybe_extract_action(self, text: str) -> Optional[Tuple[str, Dict[str, Any]]]: - """ - Extract action name and parameters from the text format: - Thought: +class Action(BaseModel): + tool_name: str + tool_params: Dict[str, Any] - Action: - { - "action": , - "action_input": - } - Args: - text (str): Input text containing the action block +class ReActOutput(BaseModel): + thought: str + action: Optional[Action] = None + observation: Optional[str] = None + answer: Optional[str] = None - Returns: - Tuple[str, Dict[str, Any]]: Tuple of (action_name, action_parameters) - Raises: - ValueError: If the action block cannot be parsed or is missing required fields - """ +class ReActOutputParser(OutputParser): + def parse(self, output_message: CompletionMessage) -> CompletionMessage: + response_text = str(output_message.content) try: - # Find the action block using regex - action_pattern = r'Action:\s*{\s*"action":\s*"([^"]+)",\s*"action_input":\s*({[^}]+})\s*}' - match = re.search(action_pattern, text, re.DOTALL) - - if not match: - raise ValueError("Could not find valid action block in text") - - action_name = match.group(1) - action_params = json.loads(match.group(2)) - - return action_name, action_params - except (ValueError, json.JSONDecodeError) as e: + response_json = json.loads(response_text) + except json.JSONDecodeError as e: print(f"Error parsing action: {e}") - return None + return output_message - def parse(self, output_message: CompletionMessage) -> CompletionMessage: - text = str(output_message.content) - action = self._maybe_extract_action(text) - if action is None or action[0] == "final_answer": + if response_json.get("answer", None): return output_message - action_name, action_params = action - call_id = str(uuid.uuid4()) - output_message.tool_calls = [ - ToolCall( - call_id=call_id, - tool_name=action_name, - arguments=action_params, - ) - ] + if response_json.get("action", None): + tool_name = response_json["action"].get("tool_name", None) + tool_params = response_json["action"].get("tool_params", None) + if tool_name and tool_params: + call_id = str(uuid.uuid4()) + output_message.tool_calls = [ + ToolCall( + call_id=call_id, tool_name=tool_name, arguments=tool_params + ) + ] + return output_message -class MetaExternalSearchTool(ClientTool): +class SearchTool(ClientTool): def get_name(self) -> str: - return "get_external_meta_data" + return "search" def get_description(self) -> str: return """ - Search the web for the given query about Meta. Get information Meta available on the public internet + Search the web for the given query. """ def get_params_definition(self) -> Dict[str, Parameter]: @@ -236,6 +229,9 @@ def get_params_definition(self) -> Dict[str, Parameter]: def run( self, messages: List[Union[UserMessage, ToolResponseMessage]] ) -> List[Union[UserMessage, ToolResponseMessage]]: + from rich.pretty import pprint + + pprint(messages) print("run for MetaExternalSearchTool called") dummy_response = """ torchtune is a PyTorch library for easily authoring, finetuning and experimenting with LLMs. @@ -263,17 +259,17 @@ def main(): base_url="http://localhost:8321", ) - model = "meta-llama/Llama-3.3-70B-Instruct" + model = "meta-llama/Llama-3.1-8B-Instruct" client_tools = [ - MetaExternalSearchTool(), + SearchTool(), ] tool_names = ", ".join([tool.get_name() for tool in client_tools]) tool_descriptions = "\n".join( [f"- {tool.get_name()}: {tool.get_description()}" for tool in client_tools] ) - instruction = REACT_PROMPT.replace("<>", tool_names).replace( + instruction = REACT_JSON_PROMPT.replace("<>", tool_names).replace( "<>", tool_descriptions ) @@ -287,11 +283,16 @@ def main(): client_tool.get_tool_definition() for client_tool in client_tools ], tool_choice="auto", - tool_prompt_format="python_list", + tool_prompt_format="json", input_shields=[], output_shields=[], enable_session_persistence=False, + # response_format={ + # "type": "json_schema", + # "json_schema": ReActOutput.model_json_schema(), + # }, ) + agent = Agent(client, agent_config, client_tools, output_parser=ReActOutputParser()) session_id = agent.create_session(f"ttest-session-{uuid.uuid4().hex}") From c205e6a138bb3ea96b60d77abca9100c05f3d401 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 5 Feb 2025 10:31:06 -0800 Subject: [PATCH 06/12] use sdk --- examples/agents/react_agent.py | 153 ++------------------------------- 1 file changed, 7 insertions(+), 146 deletions(-) diff --git a/examples/agents/react_agent.py b/examples/agents/react_agent.py index 48dcc019..2d191024 100644 --- a/examples/agents/react_agent.py +++ b/examples/agents/react_agent.py @@ -14,6 +14,9 @@ from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.client_tool import ClientTool from llama_stack_client.lib.agents.output_parser import OutputParser +from llama_stack_client.lib.agents.react.prompts import ( + DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE, +) # from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.types.agent_create_params import AgentConfig @@ -22,150 +25,9 @@ from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage from llama_stack_client.types.shared.user_message import UserMessage from llama_stack_client.types.tool_def_param import Parameter -from rich.pretty import pprint - -REACT_JSON_PROMPT = """ -You are an expert assistant who can solve any task using tool calls. You will be given a task to solve as best you can. -To do so, you have been given access to the following tools: <> - -You must always respond in the following JSON format: -{ - "thought": $THOUGHT_PROCESS, - "action": { - "tool_name": $TOOL_NAME, - "tool_params": $TOOL_PARAMS - }, - "observation": $OBSERVATION, - "answer": $ANSWER -} - -Specifically, this json should have a `thought` key, a `action` key, and an `observation` key. - -The `action` key should specify the $TOOL_NAME the name of the tool to use and the `tool_params` key should specify the parameters key as input to the tool. - -Make sure to have the $TOOL_PARAMS as a dictionary in the right format for the tool you are using, and do not put variable names as input if you can find the right values. - -You should always think about one action to take, and have the `thought` key contain your thought process about this action. -The `observation` key should contain the result of the action -... (this Thought/Action/Observation can repeat N times, you should take several steps when needed. The action key must only use a SINGLE tool at a time.) - -You can use the result of the previous action as input for the next action. -The observation will always be a string: it can represent a file, like "image_1.jpg". -Then you can use it as input for the next action. You can do it for instance as follows: - -{ - "observation": "image_1.jpg", - "thought": "I need to transform the image that I received in the previous observation to make it green.", - "action": { - "tool_name": "image_transformer", - "tool_params": {"image": "image_1.jpg"} - }, - "answer": null -} - - -To provide the final answer to the task, use the `answer` key. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this: -{ - "observation": "your observation", - "thought": "you thought process", - "action": null, - "answer": "insert your final answer here" -} - -Here are a few examples using notional tools: ---- -Task: "Generate an image of the oldest person in this document." - -{ - "thought": "I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.", - "action": { - "tool_name": "document_qa", - "tool_params": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"} - }, - "observation": "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland.", - "answer": null -} - -{ - "thought": "I will now generate an image showcasing the oldest person.", - "action": { - "tool_name": "image_generator", - "tool_params": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."} - }, - "observation": "image.png", - "answer": null -} - -{ - "thought": "I will now return the generated image.", - "action": null, - "answer": "image.png" -} - ---- -Task: "What is the result of the following operation: 5 + 3 + 1294.678?" - -{ - "thought": "I will use python code evaluator to compute the result of the operation and then return the final answer using the `final_answer` tool", - "action": { - "tool_name": "python_interpreter", - "tool_params": {"code": "5 + 3 + 1294.678"} - }, - "observation": 1302.678, - "answer": null -} - -{ - "thought": "Now that I know the result, I will now return it.", - "action": null, - "observation": null, - "answer": 1302.678 -} - ---- -Task: "Which city has the highest population , Guangzhou or Shanghai?" - -{ - "thought": "I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.", - "action": { - "tool_name": "search", - "tool_params": {"query": "Population Guangzhou"} - }, - "observation": ['Guangzhou has a population of 15 million inhabitants as of 2021.'], - "answer": null -} - -{ - "thought": "Now let's get the population of Shanghai using the tool 'search'.", - "action": { - "tool_name": "search", - "tool_params": {"query": "Population Shanghai"} - }, - "observation": "26 million (2019)", - "answer": null -} - -{ - "thought": "Now I know that Shanghai has a larger population. Let's return the result.", - "action": null, - "observation": null, - "answer": "Shanghai" -} - -Above example were using notional tools that might not exist for you. You only have access to these tools: -<> - -Here are the rules you should always follow to solve your task: -1. ALWAYS answer in the JSON format with keys "observation", "thought", "action", "answer", else you will fail. -2. Always use the right arguments for the tools. Never use variable names in the 'tool_params' field, use the value instead. -3. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself. -4. Never re-do a tool call that you previously did with the exact same parameters. -5. Observations will be provided to you, no need to generate them - -Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000. -""" from pydantic import BaseModel +from rich.pretty import pprint class Action(BaseModel): @@ -176,7 +38,6 @@ class Action(BaseModel): class ReActOutput(BaseModel): thought: str action: Optional[Action] = None - observation: Optional[str] = None answer: Optional[str] = None @@ -269,9 +130,9 @@ def main(): tool_descriptions = "\n".join( [f"- {tool.get_name()}: {tool.get_description()}" for tool in client_tools] ) - instruction = REACT_JSON_PROMPT.replace("<>", tool_names).replace( - "<>", tool_descriptions - ) + instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace( + "<>", tool_names + ).replace("<>", tool_descriptions) agent_config = AgentConfig( model=model, From 9364c5b2dcd8f5fba1c5a86580de23e80f7b501b Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 5 Feb 2025 10:35:50 -0800 Subject: [PATCH 07/12] use sdk --- examples/agents/react_agent.py | 39 ++-------------------------------- 1 file changed, 2 insertions(+), 37 deletions(-) diff --git a/examples/agents/react_agent.py b/examples/agents/react_agent.py index 2d191024..ed4b0884 100644 --- a/examples/agents/react_agent.py +++ b/examples/agents/react_agent.py @@ -3,8 +3,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# import os -import json import uuid from typing import Any, Dict, List, Optional, Union @@ -13,15 +11,12 @@ from llama_stack_client import LlamaStackClient from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.client_tool import ClientTool -from llama_stack_client.lib.agents.output_parser import OutputParser +from llama_stack_client.lib.agents.react.output_parser import ReActOutputParser from llama_stack_client.lib.agents.react.prompts import ( DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE, ) -# from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.types.agent_create_params import AgentConfig -from llama_stack_client.types.agents.turn import CompletionMessage -from llama_stack_client.types.shared.tool_call import ToolCall from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage from llama_stack_client.types.shared.user_message import UserMessage from llama_stack_client.types.tool_def_param import Parameter @@ -41,32 +36,6 @@ class ReActOutput(BaseModel): answer: Optional[str] = None -class ReActOutputParser(OutputParser): - def parse(self, output_message: CompletionMessage) -> CompletionMessage: - response_text = str(output_message.content) - try: - response_json = json.loads(response_text) - except json.JSONDecodeError as e: - print(f"Error parsing action: {e}") - return output_message - - if response_json.get("answer", None): - return output_message - - if response_json.get("action", None): - tool_name = response_json["action"].get("tool_name", None) - tool_params = response_json["action"].get("tool_params", None) - if tool_name and tool_params: - call_id = str(uuid.uuid4()) - output_message.tool_calls = [ - ToolCall( - call_id=call_id, tool_name=tool_name, arguments=tool_params - ) - ] - - return output_message - - class SearchTool(ClientTool): def get_name(self) -> str: @@ -90,10 +59,6 @@ def get_params_definition(self) -> Dict[str, Parameter]: def run( self, messages: List[Union[UserMessage, ToolResponseMessage]] ) -> List[Union[UserMessage, ToolResponseMessage]]: - from rich.pretty import pprint - - pprint(messages) - print("run for MetaExternalSearchTool called") dummy_response = """ torchtune is a PyTorch library for easily authoring, finetuning and experimenting with LLMs. @@ -107,7 +72,7 @@ def run( """ return [ ToolResponseMessage( - call_id="random-id", + call_id=messages[0].tool_calls[0].call_id, tool_name=self.get_name(), content=dummy_response, role="tool", From ea4cf429558a1a9127940fe9ecb21c656e21f8da Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 5 Feb 2025 12:04:07 -0800 Subject: [PATCH 08/12] wip builtin tool --- examples/agents/hello.py | 4 +- examples/agents/react_agent.py | 70 +++++++++++++++++++++++++++++----- 2 files changed, 63 insertions(+), 11 deletions(-) diff --git a/examples/agents/hello.py b/examples/agents/hello.py index ca47663e..a2c1d234 100644 --- a/examples/agents/hello.py +++ b/examples/agents/hello.py @@ -33,7 +33,9 @@ def main(host: str, port: int): print(f"Available shields found: {available_shields}") available_models = [ - model.identifier for model in client.models.list() if model.model_type == "llm" + model.identifier + for model in client.models.list() + if model.model_type == "llm" and "405B" not in model.identifier ] if not available_models: print(colored("No available models. Exiting.", "red")) diff --git a/examples/agents/react_agent.py b/examples/agents/react_agent.py index ed4b0884..698c58f0 100644 --- a/examples/agents/react_agent.py +++ b/examples/agents/react_agent.py @@ -11,6 +11,7 @@ from llama_stack_client import LlamaStackClient from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.client_tool import ClientTool +from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.lib.agents.react.output_parser import ReActOutputParser from llama_stack_client.lib.agents.react.prompts import ( DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE, @@ -22,6 +23,7 @@ from llama_stack_client.types.tool_def_param import Parameter from pydantic import BaseModel + from rich.pretty import pprint @@ -91,23 +93,72 @@ def main(): SearchTool(), ] - tool_names = ", ".join([tool.get_name() for tool in client_tools]) + builtin_toolgroups = [ + "builtin::websearch", + ] + + pprint(client.toolgroups.list()) + + pprint(client.tools.list(toolgroup_id="builtin::websearch")) + + # BUILTIN TOOLS + def get_tool_definition(tool): + return { + "name": tool.identifier, + "description": tool.description, + "parameters": tool.parameters, + } + + tool_names = ", ".join( + [ + tool.identifier + for tool in client.tools.list(toolgroup_id="builtin::websearch") + ] + ) tool_descriptions = "\n".join( - [f"- {tool.get_name()}: {tool.get_description()}" for tool in client_tools] + [ + f"- {tool.identifier}: {get_tool_definition(tool)}" + for tool in client.tools.list(toolgroup_id="builtin::websearch") + ] ) + + print(tool_names) + print(tool_descriptions) + + # pprint( + # client.tool_runtime.invoke_tool( + # tool_name="web_search", + # kwargs={"query": "Current time in New York"}, + # ) + # ) + # exit(1) + # instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace( + # "<>", tool_names + # ).replace("<>", tool_descriptions) + + # CLIENT TOOLS + # tool_names = ", ".join([tool.get_name() for tool in client_tools]) + # tool_descriptions = "\n".join( + # [f"- {tool.get_name()}: {tool.get_tool_definition()}" for tool in client_tools] + # ) instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace( "<>", tool_names ).replace("<>", tool_descriptions) + # print(tool_names) + # print(tool_descriptions) + # instruction = "you are a helpful assistant" + agent_config = AgentConfig( model=model, instructions=instruction, sampling_params={ "strategy": {"type": "top_p", "temperature": 1.0, "top_p": 0.9}, }, - client_tools=[ - client_tool.get_tool_definition() for client_tool in client_tools - ], + toolgroups=["builtin::websearch"], + # client_tools=[ + # client_tool.get_tool_definition() for client_tool in client_tools + # ], tool_choice="auto", tool_prompt_format="json", input_shields=[], @@ -124,13 +175,12 @@ def main(): session_id = agent.create_session(f"ttest-session-{uuid.uuid4().hex}") response = agent.create_turn( - messages=[ - {"role": "user", "content": "What model families does torchtune support?"} - ], + messages=[{"role": "user", "content": "What's the current time in new york?"}], session_id=session_id, - stream=False, + stream=True, ) - pprint(response) + for log in EventLogger().log(response): + log.print() if __name__ == "__main__": From 3f81f876fa20afcd02118393d462bbbcd79356e2 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 5 Feb 2025 12:45:49 -0800 Subject: [PATCH 09/12] enable builtin tool --- examples/agents/react_agent.py | 83 +++++++++++++++------------------- 1 file changed, 36 insertions(+), 47 deletions(-) diff --git a/examples/agents/react_agent.py b/examples/agents/react_agent.py index 698c58f0..0f313bf3 100644 --- a/examples/agents/react_agent.py +++ b/examples/agents/react_agent.py @@ -24,8 +24,6 @@ from pydantic import BaseModel -from rich.pretty import pprint - class Action(BaseModel): tool_name: str @@ -38,14 +36,14 @@ class ReActOutput(BaseModel): answer: Optional[str] = None -class SearchTool(ClientTool): +class TorchtuneTool(ClientTool): def get_name(self) -> str: - return "search" + return "torchtune" def get_description(self) -> str: return """ - Search the web for the given query. + Answer information about torchtune. """ def get_params_definition(self) -> Dict[str, Parameter]: @@ -90,17 +88,13 @@ def main(): model = "meta-llama/Llama-3.1-8B-Instruct" client_tools = [ - SearchTool(), + TorchtuneTool(), ] builtin_toolgroups = [ "builtin::websearch", ] - pprint(client.toolgroups.list()) - - pprint(client.tools.list(toolgroup_id="builtin::websearch")) - # BUILTIN TOOLS def get_tool_definition(tool): return { @@ -109,45 +103,32 @@ def get_tool_definition(tool): "parameters": tool.parameters, } - tool_names = ", ".join( - [ - tool.identifier - for tool in client.tools.list(toolgroup_id="builtin::websearch") - ] - ) - tool_descriptions = "\n".join( - [ - f"- {tool.identifier}: {get_tool_definition(tool)}" - for tool in client.tools.list(toolgroup_id="builtin::websearch") - ] - ) - - print(tool_names) - print(tool_descriptions) - - # pprint( - # client.tool_runtime.invoke_tool( - # tool_name="web_search", - # kwargs={"query": "Current time in New York"}, - # ) - # ) - # exit(1) - # instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace( - # "<>", tool_names - # ).replace("<>", tool_descriptions) + tool_names = "" + tool_descriptions = "" + for x in builtin_toolgroups: + tool_names += ", ".join( + [tool.identifier for tool in client.tools.list(toolgroup_id=x)] + ) + tool_descriptions += "\n".join( + [ + f"- {tool.identifier}: {get_tool_definition(tool)}" + for tool in client.tools.list(toolgroup_id=x) + ] + ) # CLIENT TOOLS - # tool_names = ", ".join([tool.get_name() for tool in client_tools]) - # tool_descriptions = "\n".join( - # [f"- {tool.get_name()}: {tool.get_tool_definition()}" for tool in client_tools] - # ) + tool_names += ", " + tool_descriptions += "\n" + tool_names += ", ".join([tool.get_name() for tool in client_tools]) + tool_descriptions += "\n".join( + [f"- {tool.get_name()}: {tool.get_tool_definition()}" for tool in client_tools] + ) instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace( "<>", tool_names ).replace("<>", tool_descriptions) - # print(tool_names) - # print(tool_descriptions) - # instruction = "you are a helpful assistant" + print(tool_names) + print(tool_descriptions) agent_config = AgentConfig( model=model, @@ -155,10 +136,10 @@ def get_tool_definition(tool): sampling_params={ "strategy": {"type": "top_p", "temperature": 1.0, "top_p": 0.9}, }, - toolgroups=["builtin::websearch"], - # client_tools=[ - # client_tool.get_tool_definition() for client_tool in client_tools - # ], + toolgroups=builtin_toolgroups, + client_tools=[ + client_tool.get_tool_definition() for client_tool in client_tools + ], tool_choice="auto", tool_prompt_format="json", input_shields=[], @@ -182,6 +163,14 @@ def get_tool_definition(tool): for log in EventLogger().log(response): log.print() + response2 = agent.create_turn( + messages=[{"role": "user", "content": "What is torchtune?"}], + session_id=session_id, + stream=True, + ) + for log in EventLogger().log(response2): + log.print() + if __name__ == "__main__": fire.Fire(main) From 695a192d7dc7cbc65567ffc1d17b019e70b56642 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 5 Feb 2025 13:05:22 -0800 Subject: [PATCH 10/12] use sdk react --- examples/agents/react_agent.py | 90 +++------------------------------- 1 file changed, 6 insertions(+), 84 deletions(-) diff --git a/examples/agents/react_agent.py b/examples/agents/react_agent.py index 0f313bf3..e8fa50a7 100644 --- a/examples/agents/react_agent.py +++ b/examples/agents/react_agent.py @@ -4,37 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import uuid -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Union import fire from llama_stack_client import LlamaStackClient -from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.client_tool import ClientTool from llama_stack_client.lib.agents.event_logger import EventLogger -from llama_stack_client.lib.agents.react.output_parser import ReActOutputParser -from llama_stack_client.lib.agents.react.prompts import ( - DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE, -) - -from llama_stack_client.types.agent_create_params import AgentConfig +from llama_stack_client.lib.agents.react.agent import ReActAgent from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage from llama_stack_client.types.shared.user_message import UserMessage from llama_stack_client.types.tool_def_param import Parameter -from pydantic import BaseModel - - -class Action(BaseModel): - tool_name: str - tool_params: Dict[str, Any] - - -class ReActOutput(BaseModel): - thought: str - action: Optional[Action] = None - answer: Optional[str] = None - class TorchtuneTool(ClientTool): @@ -87,72 +68,13 @@ def main(): model = "meta-llama/Llama-3.1-8B-Instruct" - client_tools = [ - TorchtuneTool(), - ] - - builtin_toolgroups = [ - "builtin::websearch", - ] - - # BUILTIN TOOLS - def get_tool_definition(tool): - return { - "name": tool.identifier, - "description": tool.description, - "parameters": tool.parameters, - } - - tool_names = "" - tool_descriptions = "" - for x in builtin_toolgroups: - tool_names += ", ".join( - [tool.identifier for tool in client.tools.list(toolgroup_id=x)] - ) - tool_descriptions += "\n".join( - [ - f"- {tool.identifier}: {get_tool_definition(tool)}" - for tool in client.tools.list(toolgroup_id=x) - ] - ) - - # CLIENT TOOLS - tool_names += ", " - tool_descriptions += "\n" - tool_names += ", ".join([tool.get_name() for tool in client_tools]) - tool_descriptions += "\n".join( - [f"- {tool.get_name()}: {tool.get_tool_definition()}" for tool in client_tools] - ) - instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace( - "<>", tool_names - ).replace("<>", tool_descriptions) - - print(tool_names) - print(tool_descriptions) - - agent_config = AgentConfig( + agent = ReActAgent( + client=client, model=model, - instructions=instruction, - sampling_params={ - "strategy": {"type": "top_p", "temperature": 1.0, "top_p": 0.9}, - }, - toolgroups=builtin_toolgroups, - client_tools=[ - client_tool.get_tool_definition() for client_tool in client_tools - ], - tool_choice="auto", - tool_prompt_format="json", - input_shields=[], - output_shields=[], - enable_session_persistence=False, - # response_format={ - # "type": "json_schema", - # "json_schema": ReActOutput.model_json_schema(), - # }, + builtin_toolgroups=["builtin::websearch"], + client_tools=[TorchtuneTool()], ) - agent = Agent(client, agent_config, client_tools, output_parser=ReActOutputParser()) - session_id = agent.create_session(f"ttest-session-{uuid.uuid4().hex}") response = agent.create_turn( From f9558081da5f04f4c071f05f4503918f8e892b9c Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Feb 2025 12:07:35 -0800 Subject: [PATCH 11/12] tool decorator working --- examples/agents/react_agent.py | 52 +++++++++------------------------- 1 file changed, 13 insertions(+), 39 deletions(-) diff --git a/examples/agents/react_agent.py b/examples/agents/react_agent.py index e8fa50a7..bf1f6b8e 100644 --- a/examples/agents/react_agent.py +++ b/examples/agents/react_agent.py @@ -4,43 +4,24 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import uuid -from typing import Dict, List, Union import fire from llama_stack_client import LlamaStackClient -from llama_stack_client.lib.agents.client_tool import ClientTool +from llama_stack_client.lib.agents.client_tool import tool from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.lib.agents.react.agent import ReActAgent -from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage -from llama_stack_client.types.shared.user_message import UserMessage -from llama_stack_client.types.tool_def_param import Parameter -class TorchtuneTool(ClientTool): +@tool +def torchtune(query: str = "torchtune"): + """ + Answer information about torchtune. - def get_name(self) -> str: - return "torchtune" - - def get_description(self) -> str: - return """ - Answer information about torchtune. - """ - - def get_params_definition(self) -> Dict[str, Parameter]: - return { - "query": Parameter( - name="query", - parameter_type="str", - description="The query to use for querying the internet", - required=True, - ) - } - - def run( - self, messages: List[Union[UserMessage, ToolResponseMessage]] - ) -> List[Union[UserMessage, ToolResponseMessage]]: - dummy_response = """ + :param query: The query to use for querying the internet + :returns: Information about torchtune + """ + dummy_response = """ torchtune is a PyTorch library for easily authoring, finetuning and experimenting with LLMs. torchtune provides: @@ -50,15 +31,8 @@ def run( Out-of-the-box memory efficiency, performance improvements, and scaling with the latest PyTorch APIs YAML configs for easily configuring training, evaluation, quantization or inference recipes Built-in support for many popular dataset formats and prompt templates - """ - return [ - ToolResponseMessage( - call_id=messages[0].tool_calls[0].call_id, - tool_name=self.get_name(), - content=dummy_response, - role="tool", - ) - ] + """ + return dummy_response def main(): @@ -67,12 +41,12 @@ def main(): ) model = "meta-llama/Llama-3.1-8B-Instruct" - + print(type(torchtune)) agent = ReActAgent( client=client, model=model, builtin_toolgroups=["builtin::websearch"], - client_tools=[TorchtuneTool()], + client_tools=[torchtune], ) session_id = agent.create_session(f"ttest-session-{uuid.uuid4().hex}") From fc4fae71df9708f74e3ba3c9a3a475d0d98c7930 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Feb 2025 13:39:55 -0800 Subject: [PATCH 12/12] update agent --- examples/agents/react_agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/agents/react_agent.py b/examples/agents/react_agent.py index bf1f6b8e..992a8ed7 100644 --- a/examples/agents/react_agent.py +++ b/examples/agents/react_agent.py @@ -8,12 +8,12 @@ import fire from llama_stack_client import LlamaStackClient -from llama_stack_client.lib.agents.client_tool import tool +from llama_stack_client.lib.agents.client_tool import client_tool from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.lib.agents.react.agent import ReActAgent -@tool +@client_tool def torchtune(query: str = "torchtune"): """ Answer information about torchtune.