From 30769779375edbefa8c1d70324546d6aa190403a Mon Sep 17 00:00:00 2001 From: Daniele Martinoli Date: Wed, 26 Feb 2025 12:05:08 +0100 Subject: [PATCH 1/6] adding the 1st configured vector_db_id, if any Signed-off-by: Daniele Martinoli --- .../providers/inline/agents/meta_reference/agent_instance.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 2a93e7b3f5..cf3cc8ef7d 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -530,8 +530,6 @@ async def _run( toolgroups.add(toolgroup) tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn) - if documents: - await self.handle_documents(session_id, documents, input_messages, tool_defs) output_attachments = [] From aa546de8d631fd2ac4fc2397c74d7e747f86dfee Mon Sep 17 00:00:00 2001 From: Daniele Martinoli Date: Fri, 28 Feb 2025 08:32:02 +0100 Subject: [PATCH 2/6] renamed insert_vector_db_id to documents_db_id, removed vector_db_id from session info Signed-off-by: Daniele Martinoli --- docs/source/building_applications/rag.md | 11 +++++-- .../agents/meta_reference/agent_instance.py | 30 +++---------------- .../agents/meta_reference/persistence.py | 11 ------- 3 files changed, 13 insertions(+), 39 deletions(-) diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md index e2e5fd6b51..08d68fefac 100644 --- a/docs/source/building_applications/rag.md +++ b/docs/source/building_applications/rag.md @@ -93,7 +93,14 @@ agent_config = AgentConfig( { "name": "builtin::rag/knowledge_search", "args": { - "vector_db_ids": [vector_db_id], + # 'documents_db_id' holds the ID of the registered vector database + # where the provided documents will be ingested. This argument is mandatory + # when the 'documents' parameter is provided in a 'create_turn' invocation. + # When provided, 'documents_db_id' will also be used to extract contextual information + # for the query. + "documents_db_id": vector_db_id, + # Optionally, the 'vector_db_ids' argument can specify additional vector databases + # to use at query time. }, } ], @@ -109,7 +116,7 @@ response = agent.create_turn( ], documents=[ { - "content": "https://raw.githubusercontent.com/example/doc.rst", + "content": "https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/memory_optimizations.rst", "mime_type": "text/plain", } ], diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index cf3cc8ef7d..816d21f584 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -911,7 +911,7 @@ def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tup async def handle_documents( self, - session_id: str, + documents_db_id: str, documents: List[Document], input_messages: List[Message], tool_defs: Dict[str, ToolDefinition], @@ -943,7 +943,7 @@ async def handle_documents( msg = await attachment_message(self.tempdir, url_items) input_messages.append(msg) # Since memory is present, add all the data to the memory bank - await self.add_to_session_vector_db(session_id, documents) + await self.add_to_session_vector_db(documents_db_id, documents) elif code_interpreter_tool: # if only code_interpreter is available, we download the URLs to a tempdir # and attach the path to them as a message to inference with the @@ -952,7 +952,7 @@ async def handle_documents( input_messages.append(msg) elif memory_tool: # if only memory is available, we load the data from the URLs and content items to the memory bank - await self.add_to_session_vector_db(session_id, documents) + await self.add_to_session_vector_db(documents_db_id, documents) else: # if no memory or code_interpreter tool is available, # we try to load the data from the URLs and content items as a message to inference @@ -961,29 +961,7 @@ async def handle_documents( [doc.content for doc in content_items] + await load_data_from_urls(url_items) ) - async def _ensure_vector_db(self, session_id: str) -> str: - session_info = await self.storage.get_session_info(session_id) - if session_info is None: - raise ValueError(f"Session {session_id} not found") - - if session_info.vector_db_id is None: - vector_db_id = f"vector_db_{session_id}" - - # TODO: the semantic for registration is definitely not "creation" - # so we need to fix it if we expect the agent to create a new vector db - # for each session - await self.vector_io_api.register_vector_db( - vector_db_id=vector_db_id, - embedding_model="all-MiniLM-L6-v2", - ) - await self.storage.add_vector_db_to_session(session_id, vector_db_id) - else: - vector_db_id = session_info.vector_db_id - - return vector_db_id - - async def add_to_session_vector_db(self, session_id: str, data: List[Document]) -> None: - vector_db_id = await self._ensure_vector_db(session_id) + async def add_to_session_vector_db(self, vector_db_id: str, data: List[Document]) -> None: documents = [ RAGDocument( document_id=str(uuid.uuid4()), diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 9a26635677..a8d5ce69ee 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -53,17 +53,6 @@ async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]: return AgentSessionInfo(**json.loads(value)) - async def add_vector_db_to_session(self, session_id: str, vector_db_id: str): - session_info = await self.get_session_info(session_id) - if session_info is None: - raise ValueError(f"Session {session_id} not found") - - session_info.vector_db_id = vector_db_id - await self.kvstore.set( - key=f"session:{self.agent_id}:{session_id}", - value=session_info.model_dump_json(), - ) - async def add_turn_to_session(self, session_id: str, turn: Turn): await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}", From 5ca575eefe4c6a83dcca22a60a5e05d0d7e6bb89 Mon Sep 17 00:00:00 2001 From: Daniele Martinoli Date: Fri, 28 Feb 2025 09:47:47 +0100 Subject: [PATCH 3/6] restored from upstream Signed-off-by: Daniele Martinoli --- .../agents/meta_reference/agent_instance.py | 32 ++++++++++++++++--- .../agents/meta_reference/persistence.py | 11 +++++++ 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 816d21f584..2a93e7b3f5 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -530,6 +530,8 @@ async def _run( toolgroups.add(toolgroup) tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn) + if documents: + await self.handle_documents(session_id, documents, input_messages, tool_defs) output_attachments = [] @@ -911,7 +913,7 @@ def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tup async def handle_documents( self, - documents_db_id: str, + session_id: str, documents: List[Document], input_messages: List[Message], tool_defs: Dict[str, ToolDefinition], @@ -943,7 +945,7 @@ async def handle_documents( msg = await attachment_message(self.tempdir, url_items) input_messages.append(msg) # Since memory is present, add all the data to the memory bank - await self.add_to_session_vector_db(documents_db_id, documents) + await self.add_to_session_vector_db(session_id, documents) elif code_interpreter_tool: # if only code_interpreter is available, we download the URLs to a tempdir # and attach the path to them as a message to inference with the @@ -952,7 +954,7 @@ async def handle_documents( input_messages.append(msg) elif memory_tool: # if only memory is available, we load the data from the URLs and content items to the memory bank - await self.add_to_session_vector_db(documents_db_id, documents) + await self.add_to_session_vector_db(session_id, documents) else: # if no memory or code_interpreter tool is available, # we try to load the data from the URLs and content items as a message to inference @@ -961,7 +963,29 @@ async def handle_documents( [doc.content for doc in content_items] + await load_data_from_urls(url_items) ) - async def add_to_session_vector_db(self, vector_db_id: str, data: List[Document]) -> None: + async def _ensure_vector_db(self, session_id: str) -> str: + session_info = await self.storage.get_session_info(session_id) + if session_info is None: + raise ValueError(f"Session {session_id} not found") + + if session_info.vector_db_id is None: + vector_db_id = f"vector_db_{session_id}" + + # TODO: the semantic for registration is definitely not "creation" + # so we need to fix it if we expect the agent to create a new vector db + # for each session + await self.vector_io_api.register_vector_db( + vector_db_id=vector_db_id, + embedding_model="all-MiniLM-L6-v2", + ) + await self.storage.add_vector_db_to_session(session_id, vector_db_id) + else: + vector_db_id = session_info.vector_db_id + + return vector_db_id + + async def add_to_session_vector_db(self, session_id: str, data: List[Document]) -> None: + vector_db_id = await self._ensure_vector_db(session_id) documents = [ RAGDocument( document_id=str(uuid.uuid4()), diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index a8d5ce69ee..9a26635677 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -53,6 +53,17 @@ async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]: return AgentSessionInfo(**json.loads(value)) + async def add_vector_db_to_session(self, session_id: str, vector_db_id: str): + session_info = await self.get_session_info(session_id) + if session_info is None: + raise ValueError(f"Session {session_id} not found") + + session_info.vector_db_id = vector_db_id + await self.kvstore.set( + key=f"session:{self.agent_id}:{session_id}", + value=session_info.model_dump_json(), + ) + async def add_turn_to_session(self, session_id: str, turn: Turn): await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}", From 1181754c5be2a337a3a0a192a7ebee846fb465e3 Mon Sep 17 00:00:00 2001 From: Daniele Martinoli Date: Fri, 28 Feb 2025 10:10:11 +0100 Subject: [PATCH 4/6] fixed test_chat_agent Signed-off-by: Daniele Martinoli --- docs/source/building_applications/rag.md | 9 +- .../meta_reference/tests/test_chat_agent.py | 101 ++++++++++-------- 2 files changed, 57 insertions(+), 53 deletions(-) diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md index 08d68fefac..7e56875fd1 100644 --- a/docs/source/building_applications/rag.md +++ b/docs/source/building_applications/rag.md @@ -93,14 +93,7 @@ agent_config = AgentConfig( { "name": "builtin::rag/knowledge_search", "args": { - # 'documents_db_id' holds the ID of the registered vector database - # where the provided documents will be ingested. This argument is mandatory - # when the 'documents' parameter is provided in a 'create_turn' invocation. - # When provided, 'documents_db_id' will also be used to extract contextual information - # for the query. - "documents_db_id": vector_db_id, - # Optionally, the 'vector_db_ids' argument can specify additional vector databases - # to use at query time. + "vector_db_ids": vector_db_id, }, } ], diff --git a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py index b802937b6e..84ab364b7e 100644 --- a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py +++ b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py @@ -16,10 +16,11 @@ AgentTurnResponseTurnCompletePayload, StepType, ) -from llama_stack.apis.common.content_types import URL +from llama_stack.apis.common.content_types import URL, TextDelta from llama_stack.apis.inference import ( ChatCompletionResponse, ChatCompletionResponseEvent, + ChatCompletionResponseEventType, ChatCompletionResponseStreamChunk, CompletionMessage, LogProbConfig, @@ -27,12 +28,15 @@ ResponseFormat, SamplingParams, ToolChoice, + ToolConfig, ToolDefinition, ToolPromptFormat, UserMessage, ) from llama_stack.apis.safety import RunShieldResponse from llama_stack.apis.tools import ( + ListToolGroupsResponse, + ListToolsResponse, Tool, ToolDef, ToolGroup, @@ -40,7 +44,7 @@ ToolInvocationResult, ) from llama_stack.apis.vector_io import QueryChunksResponse -from llama_stack.models.llama.datatypes import BuiltinTool +from llama_stack.models.llama.datatypes import BuiltinTool, StopReason from llama_stack.providers.inline.agents.meta_reference.agent_instance import ( MEMORY_QUERY_TOOL, ) @@ -54,36 +58,37 @@ class MockInferenceAPI: async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), - response_format: Optional[ResponseFormat] = None, tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = None, tool_prompt_format: Optional[ToolPromptFormat] = None, + response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: async def stream_response(): yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( - event_type="start", - delta="", + event_type=ChatCompletionResponseEventType.start, + delta=TextDelta(text=""), ) ) yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( - event_type="progress", - delta="AI is a fascinating field...", + event_type=ChatCompletionResponseEventType.progress, + delta=TextDelta(text="AI is a fascinating field..."), ) ) yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( - event_type="complete", - delta="", - stop_reason="end_of_turn", + event_type=ChatCompletionResponseEventType.complete, + delta=TextDelta(text=""), + stop_reason=StopReason.end_of_turn, ) ) @@ -133,35 +138,39 @@ async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: provider_resource_id=toolgroup_id, ) - async def list_tool_groups(self) -> List[ToolGroup]: - return [] - - async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]: - if tool_group_id == MEMORY_TOOLGROUP: - return [ - Tool( - identifier=MEMORY_QUERY_TOOL, - provider_resource_id=MEMORY_QUERY_TOOL, - toolgroup_id=MEMORY_TOOLGROUP, - tool_host=ToolHost.client, - description="Mock tool", - provider_id="builtin::rag", - parameters=[], - ) - ] - if tool_group_id == CODE_INTERPRETER_TOOLGROUP: - return [ - Tool( - identifier="code_interpreter", - provider_resource_id="code_interpreter", - toolgroup_id=CODE_INTERPRETER_TOOLGROUP, - tool_host=ToolHost.client, - description="Mock tool", - provider_id="builtin::code_interpreter", - parameters=[], - ) - ] - return [] + async def list_tool_groups(self) -> ListToolGroupsResponse: + return ListToolGroupsResponse(data=[]) + + async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse: + if toolgroup_id == MEMORY_TOOLGROUP: + return ListToolsResponse( + data=[ + Tool( + identifier=MEMORY_QUERY_TOOL, + provider_resource_id=MEMORY_QUERY_TOOL, + toolgroup_id=MEMORY_TOOLGROUP, + tool_host=ToolHost.client, + description="Mock tool", + provider_id="builtin::rag", + parameters=[], + ) + ] + ) + if toolgroup_id == CODE_INTERPRETER_TOOLGROUP: + return ListToolsResponse( + data=[ + Tool( + identifier="code_interpreter", + provider_resource_id="code_interpreter", + toolgroup_id=CODE_INTERPRETER_TOOLGROUP, + tool_host=ToolHost.client, + description="Mock tool", + provider_id="builtin::code_interpreter", + parameters=[], + ) + ] + ) + return ListToolsResponse(data=[]) async def get_tool(self, tool_name: str) -> Tool: return Tool( @@ -174,7 +183,7 @@ async def get_tool(self, tool_name: str) -> Tool: parameters=[], ) - async def unregister_tool_group(self, tool_group_id: str) -> None: + async def unregister_tool_group(self, toolgroup_id: str) -> None: pass @@ -382,10 +391,11 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex chat_agent = await impl.get_agent(response.agent_id) tool_defs, _ = await chat_agent._get_tool_defs() + tool_defs_names = [t.tool_name for t in tool_defs] if expected_memory: - assert MEMORY_QUERY_TOOL in tool_defs + assert MEMORY_QUERY_TOOL in tool_defs_names if expected_code_interpreter: - assert BuiltinTool.code_interpreter in tool_defs + assert BuiltinTool.code_interpreter in tool_defs_names if expected_memory and expected_code_interpreter: # override the tools for turn new_tool_defs, _ = await chat_agent._get_tool_defs( @@ -396,5 +406,6 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex ) ] ) - assert MEMORY_QUERY_TOOL in new_tool_defs - assert BuiltinTool.code_interpreter not in new_tool_defs + new_tool_defs_names = [t.tool_name for t in new_tool_defs] + assert MEMORY_QUERY_TOOL in new_tool_defs_names + assert BuiltinTool.code_interpreter not in new_tool_defs_names From 2d0ad6ba3f7f392689134dee26697c5c7f06a376 Mon Sep 17 00:00:00 2001 From: Daniele Martinoli Date: Fri, 28 Feb 2025 10:11:56 +0100 Subject: [PATCH 5/6] fixed RAG doc (broken URL) Signed-off-by: Daniele Martinoli --- docs/source/building_applications/rag.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md index 7e56875fd1..47871ac9f8 100644 --- a/docs/source/building_applications/rag.md +++ b/docs/source/building_applications/rag.md @@ -93,7 +93,7 @@ agent_config = AgentConfig( { "name": "builtin::rag/knowledge_search", "args": { - "vector_db_ids": vector_db_id, + "vector_db_ids": [vector_db_id], }, } ], From ff3384dad6e1721fd92432a630b426cd767938ff Mon Sep 17 00:00:00 2001 From: Daniele Martinoli Date: Fri, 28 Feb 2025 11:39:23 +0100 Subject: [PATCH 6/6] In case of missing provider_id, use the first one (if any) to register an ephemeral vector db Signed-off-by: Daniele Martinoli --- llama_stack/distribution/routers/routing_tables.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 80e9ecb7cb..1a1142c52b 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -309,13 +309,14 @@ async def register_vector_db( if provider_vector_db_id is None: provider_vector_db_id = vector_db_id if provider_id is None: - # If provider_id not specified, use the only provider if it supports this shield type - if len(self.impls_by_provider_id) == 1: + if len(self.impls_by_provider_id) > 0: provider_id = list(self.impls_by_provider_id.keys())[0] + if len(self.impls_by_provider_id) > 1: + logger.warning( + f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." + ) else: - raise ValueError( - "No provider specified and multiple providers available. Please specify a provider_id." - ) + raise ValueError("No provider available. Please configure a vector_io provider.") model = await self.get_object_by_identifier("model", embedding_model) if model is None: raise ValueError(f"Model {embedding_model} not found")