Merged
132 changes: 71 additions & 61 deletions docs/_static/llama-stack-spec.html
@@ -4229,70 +4229,80 @@
            ]
          },
          "arguments": {
-           "type": "object",
-           "additionalProperties": {
-             "oneOf": [
-               {
-                 "type": "string"
-               },
-               {
-                 "type": "integer"
-               },
-               {
-                 "type": "number"
-               },
-               {
-                 "type": "boolean"
-               },
-               {
-                 "type": "null"
-               },
-               {
-                 "type": "array",
-                 "items": {
-                   "oneOf": [
-                     {
-                       "type": "string"
-                     },
-                     {
-                       "type": "integer"
-                     },
-                     {
-                       "type": "number"
-                     },
-                     {
-                       "type": "boolean"
-                     },
-                     {
-                       "type": "null"
-                     }
-                   ]
-                 }
-               },
-               {
-                 "type": "object",
-                 "additionalProperties": {
-                   "oneOf": [
-                     {
-                       "type": "string"
-                     },
-                     {
-                       "type": "integer"
-                     },
-                     {
-                       "type": "number"
-                     },
-                     {
-                       "type": "boolean"
-                     },
-                     {
-                       "type": "null"
-                     }
-                   ]
-                 }
-               }
-             ]
-           }
+           "oneOf": [
+             {
+               "type": "string"
+             },
+             {
+               "type": "object",
+               "additionalProperties": {
+                 "oneOf": [
+                   {
+                     "type": "string"
+                   },
+                   {
+                     "type": "integer"
+                   },
+                   {
+                     "type": "number"
+                   },
+                   {
+                     "type": "boolean"
+                   },
+                   {
+                     "type": "null"
+                   },
+                   {
+                     "type": "array",
+                     "items": {
+                       "oneOf": [
+                         {
+                           "type": "string"
+                         },
+                         {
+                           "type": "integer"
+                         },
+                         {
+                           "type": "number"
+                         },
+                         {
+                           "type": "boolean"
+                         },
+                         {
+                           "type": "null"
+                         }
+                       ]
+                     }
+                   },
+                   {
+                     "type": "object",
+                     "additionalProperties": {
+                       "oneOf": [
+                         {
+                           "type": "string"
+                         },
+                         {
+                           "type": "integer"
+                         },
+                         {
+                           "type": "number"
+                         },
+                         {
+                           "type": "boolean"
+                         },
+                         {
+                           "type": "null"
+                         }
+                       ]
+                     }
+                   }
+                 ]
+               }
+             }
+           ]
          },
+         "arguments_json": {
+           "type": "string"
+         }
        },
        "additionalProperties": false,
52 changes: 28 additions & 24 deletions docs/_static/llama-stack-spec.yaml
@@ -2884,30 +2884,34 @@ components:
            title: BuiltinTool
          - type: string
        arguments:
-         type: object
-         additionalProperties:
-           oneOf:
-             - type: string
-             - type: integer
-             - type: number
-             - type: boolean
-             - type: 'null'
-             - type: array
-               items:
-                 oneOf:
-                   - type: string
-                   - type: integer
-                   - type: number
-                   - type: boolean
-                   - type: 'null'
-             - type: object
-               additionalProperties:
-                 oneOf:
-                   - type: string
-                   - type: integer
-                   - type: number
-                   - type: boolean
-                   - type: 'null'
+         oneOf:
+           - type: string
+           - type: object
+             additionalProperties:
+               oneOf:
+                 - type: string
+                 - type: integer
+                 - type: number
+                 - type: boolean
+                 - type: 'null'
+                 - type: array
+                   items:
+                     oneOf:
+                       - type: string
+                       - type: integer
+                       - type: number
+                       - type: boolean
+                       - type: 'null'
+                 - type: object
+                   additionalProperties:
+                     oneOf:
+                       - type: string
+                       - type: integer
+                       - type: number
+                       - type: boolean
+                       - type: 'null'
+       arguments_json:
+         type: string
      additionalProperties: false
      required:
        - call_id
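For illustration only (tool name and fields are made up, not from this PR): under the updated schema, a ToolCall's `arguments` may be either the legacy dict of primitives or a plain JSON string, with `arguments_json` carrying the serialized twin. A minimal sketch of both shapes:

```python
# Hypothetical payloads, both valid under the new spec.
legacy_style = {
    "call_id": "call-1",
    "tool_name": "get_weather",                  # made-up tool
    "arguments": {"city": "Tokyo", "days": 3},   # old shape: dict of primitives
}

new_style = {
    "call_id": "call-2",
    "tool_name": "get_weather",
    "arguments": '{"city": "Tokyo", "days": 3}',        # new shape: raw JSON string
    "arguments_json": '{"city": "Tokyo", "days": 3}',   # new optional field
}
```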
9 changes: 8 additions & 1 deletion llama_stack/models/llama/datatypes.py
@@ -47,7 +47,14 @@ class BuiltinTool(Enum):
class ToolCall(BaseModel):
    call_id: str
    tool_name: Union[BuiltinTool, str]
-    arguments: Dict[str, RecursiveType]
+    # Plan is to deprecate the Dict in favor of a JSON string
+    # that is parsed on the client side instead of trying to manage
+    # the recursive type here.
+    # Making this a union so that client side can start prepping for this change.
+    # Eventually, we will remove both the Dict and arguments_json field,
+    # and arguments will just be a str
+    arguments: Union[str, Dict[str, RecursiveType]]
+    arguments_json: Optional[str] = None

    @field_validator("tool_name", mode="before")
    @classmethod

Contributor, commenting on the new `arguments` annotation: can you add a comment here explaining why there is a Union[str, ...] and how the deprecation process will happen?
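A minimal client-side sketch of how a consumer might prepare for the deprecation described in the new comment; `normalize_arguments` is a hypothetical helper, not part of this PR:

```python
import json
from typing import Any, Dict, Union


def normalize_arguments(arguments: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Accept both shapes during the deprecation window."""
    if isinstance(arguments, str):
        # New shape: the server sent a JSON string; parse it locally.
        return json.loads(arguments)
    # Legacy shape: already a dict, pass through unchanged.
    return arguments
```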
9 changes: 6 additions & 3 deletions llama_stack/models/llama/llama3/chat_format.py
@@ -12,6 +12,7 @@
# the top-level of this source tree.

import io
+import json
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
@@ -203,9 +204,10 @@ def decode_assistant_message_from_content(self, content: str, stop_reason: StopR
                # This code tries to handle that case
                if tool_name in BuiltinTool.__members__:
                    tool_name = BuiltinTool[tool_name]
-                    tool_arguments = {
-                        "query": list(tool_arguments.values())[0],
-                    }
+                    if isinstance(tool_arguments, dict):
+                        tool_arguments = {
+                            "query": list(tool_arguments.values())[0],
+                        }
            else:
                builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
                if builtin_tool_info is not None:
@@ -229,6 +231,7 @@ def decode_assistant_message_from_content(self, content: str, stop_reason: StopR
                        call_id=call_id,
                        tool_name=tool_name,
                        arguments=tool_arguments,
+                        arguments_json=json.dumps(tool_arguments),
                    )
                )
            content = ""
7 changes: 2 additions & 5 deletions llama_stack/models/llama/llama3/template_data.py
@@ -11,11 +11,8 @@
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

-from llama_stack.models.llama.datatypes import (
-    BuiltinTool,
-    StopReason,
-    ToolCall,
-)

+from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall

from .prompt_templates import (
    BuiltinToolGenerator,
1 change: 1 addition & 0 deletions llama_stack/providers/inline/inference/vllm/vllm.py
@@ -582,6 +582,7 @@ def _convert_non_streaming_results(
                        tool_name=t.function.name,
                        # vLLM function args come back as a string. Llama Stack expects JSON.
                        arguments=json.loads(t.function.arguments),
+                        arguments_json=t.function.arguments,
                    )
                    for t in vllm_message.tool_calls
                ],
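Passing the provider's raw string through as `arguments_json`, rather than re-serializing the parsed dict, keeps the value byte-for-byte faithful. A small sketch of the round-trip drift this avoids, using a made-up payload:

```python
import json

raw = '{"city": "Tokyo",  "days": 3}'  # hypothetical string from the provider
parsed = json.loads(raw)
print(json.dumps(parsed))              # {"city": "Tokyo", "days": 3} -- whitespace normalized
print(json.dumps(parsed) == raw)       # False: the round trip is not byte-identical
```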
10 changes: 3 additions & 7 deletions llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -42,9 +42,7 @@
    TopKSamplingStrategy,
    TopPSamplingStrategy,
)
-from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
-)
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
    process_chat_completion_stream_response,
)
@@ -293,14 +291,12 @@ def convert_to_sambanova_tool_calls(
        if not tool_calls:
            return []

-        for call in tool_calls:
-            call_function_arguments = json.loads(call.function.arguments)
-
        compitable_tool_calls = [
            ToolCall(
                call_id=call.id,
                tool_name=call.function.name,
-                arguments=call_function_arguments,
+                arguments=json.loads(call.function.arguments),
+                arguments_json=call.function.arguments,
            )
            for call in tool_calls
        ]

Besides adding arguments_json, this fixes a latent bug: the removed loop overwrote call_function_arguments on every iteration, so every converted ToolCall received the arguments of the last call.
8 changes: 3 additions & 5 deletions llama_stack/providers/remote/inference/vllm/vllm.py
@@ -89,15 +89,12 @@ def _convert_to_vllm_tool_calls_in_response(
        if not tool_calls:
            return []

-        call_function_arguments = None
-        for call in tool_calls:
-            call_function_arguments = json.loads(call.function.arguments)
-
        return [
            ToolCall(
                call_id=call.id,
                tool_name=call.function.name,
-                arguments=call_function_arguments,
+                arguments=json.loads(call.function.arguments),
+                arguments_json=call.function.arguments,
            )
            for call in tool_calls
        ]
@@ -182,6 +179,7 @@ async def _process_vllm_chat_completion_stream_response(
                                call_id=tool_call_buf.call_id,
                                tool_name=tool_call_buf.tool_name,
                                arguments=args,
+                                arguments_json=args_str,
                            ),
                            parse_status=ToolCallParseStatus.succeeded,
                        ),
14 changes: 11 additions & 3 deletions llama_stack/providers/utils/inference/openai_compat.py
@@ -529,7 +529,11 @@ async def _convert_message_content(
    ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
        async def impl(
            content_: InterleavedContent,
-        ) -> Union[str, OpenAIChatCompletionContentPartParam, List[OpenAIChatCompletionContentPartParam]]:
+        ) -> Union[
+            str,
+            OpenAIChatCompletionContentPartParam,
+            List[OpenAIChatCompletionContentPartParam],
+        ]:
            # Llama Stack and OpenAI spec match for str and text input
            if isinstance(content_, str):
                return content_
@@ -570,7 +574,7 @@ async def impl(
                OpenAIChatCompletionMessageToolCall(
                    id=tool.call_id,
                    function=OpenAIFunction(
-                        name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value,
+                        name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
                        arguments=json.dumps(tool.arguments),
                    ),
                    type="function",
@@ -609,6 +613,7 @@ def convert_tool_call(
            call_id=tool_call.id,
            tool_name=tool_call.function.name,
            arguments=json.loads(tool_call.function.arguments),
+            arguments_json=tool_call.function.arguments,
        )
    except Exception:
        return UnparseableToolCall(
@@ -759,6 +764,7 @@ def _convert_openai_tool_calls(
            call_id=call.id,
            tool_name=call.function.name,
            arguments=json.loads(call.function.arguments),
+            arguments_json=call.function.arguments,
        )
        for call in tool_calls
    ]
@@ -890,7 +896,8 @@ async def convert_openai_chat_completion_stream(
                # ChatCompletionResponseEvent only supports one per stream
                if len(choice.delta.tool_calls) > 1:
                    warnings.warn(
-                        "multiple tool calls found in a single delta, using the first, ignoring the rest", stacklevel=2
+                        "multiple tool calls found in a single delta, using the first, ignoring the rest",
+                        stacklevel=2,
                    )

                if not enable_incremental_tool_calls:
@@ -971,6 +978,7 @@
                            call_id=buffer["call_id"],
                            tool_name=buffer["name"],
                            arguments=arguments,
+                            arguments_json=buffer["arguments"],
                        )
                        yield ChatCompletionResponseStreamChunk(
                            event=ChatCompletionResponseEvent(
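The streaming path above only parses arguments once a tool call is complete. A simplified, hypothetical sketch of that buffering pattern, with made-up deltas:

```python
import json

# Fragments of the arguments string arrive across several stream deltas.
buffer = {"call_id": "call-1", "name": "get_weather", "arguments": ""}
for delta in ['{"city": ', '"Tokyo"}']:      # made-up deltas
    buffer["arguments"] += delta

arguments = json.loads(buffer["arguments"])  # parsed dict -> ToolCall.arguments
arguments_json = buffer["arguments"]         # verbatim string -> ToolCall.arguments_json
```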
5 changes: 4 additions & 1 deletion tests/unit/models/test_prompt_adapter.py
@@ -165,7 +165,10 @@ async def test_completion_message_encoding(self):
        request.model = MODEL
        request.tool_config.tool_prompt_format = ToolPromptFormat.json
        prompt = await chat_completion_request_to_prompt(request, request.model)
-        self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt)
+        self.assertIn(
+            '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
+            prompt,
+        )

    async def test_user_provided_system_message(self):
        content = "Hello !"