From 549096b264ab71cb7eadb852c7ed62ff00235a19 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Tue, 18 Mar 2025 13:44:29 -0700 Subject: [PATCH 1/3] new param arguments_json in ToolCall --- docs/_static/llama-stack-spec.html | 132 ++++++++++-------- docs/_static/llama-stack-spec.yaml | 52 +++---- llama_stack/models/llama/datatypes.py | 4 +- .../models/llama/llama3/chat_format.py | 9 +- .../models/llama/llama3/template_data.py | 7 +- .../providers/inline/inference/vllm/vllm.py | 1 + .../remote/inference/sambanova/sambanova.py | 10 +- .../providers/remote/inference/vllm/vllm.py | 5 +- .../utils/inference/openai_compat.py | 14 +- tests/unit/models/test_prompt_adapter.py | 5 +- 10 files changed, 132 insertions(+), 107 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 210a84b032..77b73fb5c6 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -4229,70 +4229,80 @@ ] }, "arguments": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - }, - { - "type": "number" - }, - { - "type": "boolean" - }, - { - "type": "null" - }, - { - "type": "array", - "items": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - }, - { - "type": "number" - }, - { - "type": "boolean" - }, - { - "type": "null" + "oneOf": [ + { + "type": "string" + }, + { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "integer" + }, + { + "type": "number" + }, + { + "type": "boolean" + }, + { + "type": "null" + }, + { + "type": "array", + "items": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "integer" + }, + { + "type": "number" + }, + { + "type": "boolean" + }, + { + "type": "null" + } + ] } - ] - } - }, - { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - }, - { - "type": "number" - }, - { - "type": "boolean" - }, - { - "type": "null" + }, + { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "integer" + }, + { + "type": "number" + }, + { + "type": "boolean" + }, + { + "type": "null" + } + ] } - ] - } + } + ] } - ] - } + } + ] + }, + "arguments_json": { + "type": "string" } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index a1eb07444c..5d4680c1f6 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -2884,30 +2884,34 @@ components: title: BuiltinTool - type: string arguments: - type: object - additionalProperties: - oneOf: - - type: string - - type: integer - - type: number - - type: boolean - - type: 'null' - - type: array - items: - oneOf: - - type: string - - type: integer - - type: number - - type: boolean - - type: 'null' - - type: object - additionalProperties: - oneOf: - - type: string - - type: integer - - type: number - - type: boolean - - type: 'null' + oneOf: + - type: string + - type: object + additionalProperties: + oneOf: + - type: string + - type: integer + - type: number + - type: boolean + - type: 'null' + - type: array + items: + oneOf: + - type: string + - type: integer + - type: number + - type: boolean + - type: 'null' + - type: object + additionalProperties: + oneOf: + - type: string + - type: integer + - type: number + - type: boolean + - type: 'null' + arguments_json: + type: string additionalProperties: false required: - call_id diff --git 
a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py index b25bf0ea96..1a9bc15045 100644 --- a/llama_stack/models/llama/datatypes.py +++ b/llama_stack/models/llama/datatypes.py @@ -47,7 +47,9 @@ class BuiltinTool(Enum): class ToolCall(BaseModel): call_id: str tool_name: Union[BuiltinTool, str] - arguments: Dict[str, RecursiveType] + arguments: Union[str, Dict[str, RecursiveType]] + # Temporary field for backwards compatibility + arguments_json: Optional[str] = None @field_validator("tool_name", mode="before") @classmethod diff --git a/llama_stack/models/llama/llama3/chat_format.py b/llama_stack/models/llama/llama3/chat_format.py index 011ccb02a9..2862f85582 100644 --- a/llama_stack/models/llama/llama3/chat_format.py +++ b/llama_stack/models/llama/llama3/chat_format.py @@ -12,6 +12,7 @@ # the top-level of this source tree. import io +import json import uuid from dataclasses import dataclass from typing import Dict, List, Optional, Tuple @@ -203,9 +204,10 @@ def decode_assistant_message_from_content(self, content: str, stop_reason: StopR # This code tries to handle that case if tool_name in BuiltinTool.__members__: tool_name = BuiltinTool[tool_name] - tool_arguments = { - "query": list(tool_arguments.values())[0], - } + if isinstance(tool_arguments, dict): + tool_arguments = { + "query": list(tool_arguments.values())[0], + } else: builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content) if builtin_tool_info is not None: @@ -229,6 +231,7 @@ def decode_assistant_message_from_content(self, content: str, stop_reason: StopR call_id=call_id, tool_name=tool_name, arguments=tool_arguments, + arguments_json=json.dumps(tool_arguments), ) ) content = "" diff --git a/llama_stack/models/llama/llama3/template_data.py b/llama_stack/models/llama/llama3/template_data.py index aa16aa009c..076b4adb49 100644 --- a/llama_stack/models/llama/llama3/template_data.py +++ b/llama_stack/models/llama/llama3/template_data.py @@ -11,11 +11,8 @@ # top-level folder for each specific model found within the models/ directory at # the top-level of this source tree. -from llama_stack.models.llama.datatypes import ( - BuiltinTool, - StopReason, - ToolCall, -) + +from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall from .prompt_templates import ( BuiltinToolGenerator, diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index b59df13d0d..256e0f821d 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -582,6 +582,7 @@ def _convert_non_streaming_results( tool_name=t.function.name, # vLLM function args come back as a string. Llama Stack expects JSON. 
arguments=json.loads(t.function.arguments), + arguments_json=t.function.arguments, ) for t in vllm_message.tool_calls ], diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index a5e17c2a3d..635a42d385 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -42,9 +42,7 @@ TopKSamplingStrategy, TopPSamplingStrategy, ) -from llama_stack.providers.utils.inference.model_registry import ( - ModelRegistryHelper, -) +from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import ( process_chat_completion_stream_response, ) @@ -293,14 +291,12 @@ def convert_to_sambanova_tool_calls( if not tool_calls: return [] - for call in tool_calls: - call_function_arguments = json.loads(call.function.arguments) - compitable_tool_calls = [ ToolCall( call_id=call.id, tool_name=call.function.name, - arguments=call_function_arguments, + arguments=json.loads(call.function.arguments), + arguments_json=call.function.arguments, ) for call in tool_calls ] diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 4d7e66d787..8e494a846a 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -97,7 +97,8 @@ def _convert_to_vllm_tool_calls_in_response( ToolCall( call_id=call.id, tool_name=call.function.name, - arguments=call_function_arguments, + arguments=json.loads(call.function.arguments), + arguments_json=call_function_arguments, ) for call in tool_calls ] @@ -181,7 +182,7 @@ async def _process_vllm_chat_completion_stream_response( tool_call=ToolCall( call_id=tool_call_buf.call_id, tool_name=tool_call_buf.tool_name, - arguments=args, + arguments=args_str, ), parse_status=ToolCallParseStatus.succeeded, ), diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 2a362f8cbf..b264c73129 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -529,7 +529,11 @@ async def _convert_message_content( ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]: async def impl( content_: InterleavedContent, - ) -> Union[str, OpenAIChatCompletionContentPartParam, List[OpenAIChatCompletionContentPartParam]]: + ) -> Union[ + str, + OpenAIChatCompletionContentPartParam, + List[OpenAIChatCompletionContentPartParam], + ]: # Llama Stack and OpenAI spec match for str and text input if isinstance(content_, str): return content_ @@ -570,7 +574,7 @@ async def impl( OpenAIChatCompletionMessageToolCall( id=tool.call_id, function=OpenAIFunction( - name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value, + name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value), arguments=json.dumps(tool.arguments), ), type="function", @@ -609,6 +613,7 @@ def convert_tool_call( call_id=tool_call.id, tool_name=tool_call.function.name, arguments=json.loads(tool_call.function.arguments), + arguments_json=tool_call.function.arguments, ) except Exception: return UnparseableToolCall( @@ -759,6 +764,7 @@ def _convert_openai_tool_calls( call_id=call.id, tool_name=call.function.name, arguments=json.loads(call.function.arguments), + 
arguments_json=call.function.arguments, ) for call in tool_calls ] @@ -890,7 +896,8 @@ async def convert_openai_chat_completion_stream( # ChatCompletionResponseEvent only supports one per stream if len(choice.delta.tool_calls) > 1: warnings.warn( - "multiple tool calls found in a single delta, using the first, ignoring the rest", stacklevel=2 + "multiple tool calls found in a single delta, using the first, ignoring the rest", + stacklevel=2, ) if not enable_incremental_tool_calls: @@ -971,6 +978,7 @@ async def convert_openai_chat_completion_stream( call_id=buffer["call_id"], tool_name=buffer["name"], arguments=arguments, + arguments_json=buffer["arguments"], ) yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( diff --git a/tests/unit/models/test_prompt_adapter.py b/tests/unit/models/test_prompt_adapter.py index c3755e2cb8..0e2780e50a 100644 --- a/tests/unit/models/test_prompt_adapter.py +++ b/tests/unit/models/test_prompt_adapter.py @@ -165,7 +165,10 @@ async def test_completion_message_encoding(self): request.model = MODEL request.tool_config.tool_prompt_format = ToolPromptFormat.json prompt = await chat_completion_request_to_prompt(request, request.model) - self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt) + self.assertIn( + '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', + prompt, + ) async def test_user_provided_system_message(self): content = "Hello !" From b0425c84b5be597d85e476e2c554161921ca4ad1 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Tue, 18 Mar 2025 13:57:36 -0700 Subject: [PATCH 2/3] fix vllm bugs --- llama_stack/providers/remote/inference/vllm/vllm.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 8e494a846a..6522820e02 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -89,16 +89,12 @@ def _convert_to_vllm_tool_calls_in_response( if not tool_calls: return [] - call_function_arguments = None - for call in tool_calls: - call_function_arguments = json.loads(call.function.arguments) - return [ ToolCall( call_id=call.id, tool_name=call.function.name, arguments=json.loads(call.function.arguments), - arguments_json=call_function_arguments, + arguments_json=call.function.arguments, ) for call in tool_calls ] @@ -182,7 +178,8 @@ async def _process_vllm_chat_completion_stream_response( tool_call=ToolCall( call_id=tool_call_buf.call_id, tool_name=tool_call_buf.tool_name, - arguments=args_str, + arguments=args, + arguments_json=args_str, ), parse_status=ToolCallParseStatus.succeeded, ), From 702b003764f0b9915e754f2533e4611c7568c296 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Wed, 19 Mar 2025 10:02:06 -0700 Subject: [PATCH 3/3] added documentation --- llama_stack/models/llama/datatypes.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py index 1a9bc15045..9842d79807 100644 --- a/llama_stack/models/llama/datatypes.py +++ b/llama_stack/models/llama/datatypes.py @@ -47,8 +47,13 @@ class BuiltinTool(Enum): class ToolCall(BaseModel): call_id: str tool_name: Union[BuiltinTool, str] + # Plan is to deprecate the Dict in favor of a JSON string + # that is parsed on the client side instead of trying to manage + # the recursive type here. 
+    # Making this a union so that the client side can start prepping for this change.
+    # Eventually, we will remove both the Dict and the arguments_json field,
+    # and arguments will just be a str.
     arguments: Union[str, Dict[str, RecursiveType]]
-    # Temporary field for backwards compatibility
     arguments_json: Optional[str] = None

     @field_validator("tool_name", mode="before")
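
Below is a minimal usage sketch, assuming only the ToolCall model as changed
by this series; the tool_call_args helper and the sample values ("call-0",
"custom1") are illustrative, not part of the patch. It shows the provider-side
convention used throughout the series (parse once for `arguments`, keep the
raw string in `arguments_json`) and a client-side read path that keeps working
after the planned migration of `arguments` to a plain JSON string.

    import json
    from typing import Any, Dict

    from llama_stack.models.llama.datatypes import ToolCall


    def tool_call_args(tool_call: ToolCall) -> Dict[str, Any]:
        # Hypothetical helper: prefer the raw JSON string when the server
        # populates it, then fall back to the legacy dict form.
        if tool_call.arguments_json is not None:
            return json.loads(tool_call.arguments_json)
        if isinstance(tool_call.arguments, str):
            # After the migration, `arguments` itself will be a JSON string.
            return json.loads(tool_call.arguments)
        return tool_call.arguments


    # Provider side, mirroring the converters in this series: parse the
    # upstream string once for `arguments`, keep the original verbatim.
    raw = '{"param1": "value1"}'
    call = ToolCall(
        call_id="call-0",
        tool_name="custom1",
        arguments=json.loads(raw),
        arguments_json=raw,
    )
    assert tool_call_args(call) == {"param1": "value1"}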