Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,6 @@ exclude = [
"^src/llama_stack/core/routing_tables/",
# Provider directories - Inline
"^src/llama_stack/providers/inline/datasetio/localfs/",
"^src/llama_stack/providers/inline/responses/builtin/",
"^src/llama_stack/providers/inline/safety/code_scanner/",
"^src/llama_stack/providers/inline/safety/llama_guard/",
"^src/llama_stack/providers/inline/scoring/basic/",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ async def create_openai_response(
presence_penalty: float | None = None,
extra_body: dict | None = None,
stream_options: ResponseStreamOptions | None = None,
):
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
stream = bool(stream)
background = bool(background)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
Expand Down Expand Up @@ -841,7 +841,9 @@ async def _create_background_response(
created_at = int(time.time())

# Normalize input to list format for storage
input_items = [OpenAIResponseMessage(content=input, role="user")] if isinstance(input, str) else input
input_items: list[OpenAIResponseInput] = (
[OpenAIResponseMessage(content=input, role="user")] if isinstance(input, str) else input
)

# Create initial queued response
queued_response = OpenAIResponseObject(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionCustomToolCall,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionResponseMessage,
OpenAIChatCompletionToolCall,
Expand Down Expand Up @@ -176,7 +177,7 @@ def extract_openai_error(exc: Exception) -> tuple[str, str]:
raw_message = body.get("message")

if raw_code and isinstance(raw_code, str):
final_code: str = _RESPONSES_API_ERROR_CODES.get(raw_code, raw_code)
final_code: str = _RESPONSES_API_ERROR_CODES[raw_code] if raw_code in _RESPONSES_API_ERROR_CODES else raw_code
else:
final_code = "server_error"

Expand Down Expand Up @@ -424,10 +425,11 @@ async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
async for stream_event in self._process_tools(output_messages):
yield stream_event

chat_tool_choice = None
chat_tool_choice: str | dict[str, Any] | None = None
# Track allowed tools for filtering (persists across iterations)
allowed_tool_names: set[str] | None = None
if self.ctx.tool_choice and len(self.ctx.chat_tools) > 0:
# check truthiness of self.ctx.chat_tools to avoid len(None)
if self.ctx.tool_choice and self.ctx.chat_tools:
processed_tool_choice = await _process_tool_choice(
self.ctx.chat_tools,
self.ctx.tool_choice,
Expand Down Expand Up @@ -484,7 +486,7 @@ async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
)
# Filter tools to only allowed ones if tool_choice specified an allowed list
effective_tools = self.ctx.chat_tools
if allowed_tool_names is not None:
if allowed_tool_names is not None and self.ctx.chat_tools is not None:
effective_tools = [
tool
for tool in self.ctx.chat_tools
Expand Down Expand Up @@ -525,7 +527,7 @@ async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
parallel_tool_calls=effective_parallel_tool_calls,
reasoning_effort=self.reasoning.effort if self.reasoning else None,
safety_identifier=self.safety_identifier,
service_tier=self.service_tier,
service_tier=ServiceTier(self.service_tier) if self.service_tier else None,
max_completion_tokens=remaining_output_tokens,
prompt_cache_key=self.prompt_cache_key,
top_logprobs=self.top_logprobs,
Expand Down Expand Up @@ -1224,13 +1226,14 @@ async def _process_streaming_chunks(
message_item_id=message_item_id,
tool_call_item_ids=tool_call_item_ids,
content_part_emitted=content_part_emitted,
logprobs=OpenAIChoiceLogprobs(content=chat_response_logprobs) if chat_response_logprobs else None,
logprobs=chat_response_logprobs if chat_response_logprobs else None,
service_tier=chunk_service_tier,
)

def _build_chat_completion(self, result: ChatCompletionResult) -> OpenAIChatCompletion:
"""Build OpenAIChatCompletion from ChatCompletionResult."""
# Convert collected chunks to complete response
tool_calls: list[OpenAIChatCompletionToolCall | OpenAIChatCompletionCustomToolCall] | None
if result.tool_calls:
tool_calls = [result.tool_calls[i] for i in sorted(result.tool_calls.keys())]
else:
Expand All @@ -1247,7 +1250,7 @@ def _build_chat_completion(self, result: ChatCompletionResult) -> OpenAIChatComp
message=assistant_message,
finish_reason=result.finish_reason,
index=0,
logprobs=result.logprobs,
logprobs=OpenAIChoiceLogprobs(content=result.logprobs) if result.logprobs else None,
)
],
created=result.created,
Expand Down Expand Up @@ -1409,8 +1412,8 @@ def make_openai_tool(tool_name: str, tool: ToolDef) -> ChatCompletionToolParam:
for input_tool in tools:
if input_tool.type == "function":
self.ctx.chat_tools.append(
ChatCompletionToolParam(type="function", function=input_tool.model_dump(exclude_none=True))
) # type: ignore[typeddict-item,arg-type] # Dict compatible with FunctionDefinition
ChatCompletionToolParam(type="function", function=input_tool.model_dump(exclude_none=True)) # type: ignore[typeddict-item,arg-type] # Dict compatible with FunctionDefinition
)
elif input_tool.type in WebSearchToolTypes:
tool_name = "web_search"
# Need to access tool_groups_api from tool_executor
Expand Down Expand Up @@ -1446,7 +1449,8 @@ async def _process_mcp_tool(
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Process an MCP tool configuration and emit appropriate streaming events."""
# Resolve connector_id to server_url if provided
mcp_tool = await resolve_mcp_connector_id(mcp_tool, self.connectors_api)
if self.connectors_api is not None:
mcp_tool = await resolve_mcp_connector_id(mcp_tool, self.connectors_api)

# Emit mcp_list_tools.in_progress
self.sequence_number += 1
Expand All @@ -1467,15 +1471,21 @@ async def _process_mcp_tool(
# Call list_mcp_tools
tool_defs = None
list_id = f"mcp_list_{uuid.uuid4()}"

# Get session manager from tool_executor if available (fix for #4452)
session_manager = getattr(self.tool_executor, "mcp_session_manager", None)

if not mcp_tool.server_url:
raise ValueError(
f"Failed to list MCP tools for server '{mcp_tool.server_label}': server_url is not set"
)

attributes = {
"server_label": mcp_tool.server_label,
"server_url": mcp_tool.server_url,
"mcp_list_tools_id": list_id,
}

# Get session manager from tool_executor if available (fix for #4452)
session_manager = getattr(self.tool_executor, "mcp_session_manager", None)

# TODO: follow semantic conventions for Open Telemetry tool spans
# https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span
with tracer.start_as_current_span("list_mcp_tools", attributes=attributes):
Expand Down Expand Up @@ -1681,7 +1691,7 @@ async def _process_tool_choice(

elif isinstance(tool_choice, OpenAIResponseInputToolChoiceAllowedTools):
# ensure that specified tool choices are available in the chat tools, if not, remove them from the list
final_tools = []
final_tools: list[dict[str, Any]] = []
for tool in tool_choice.tools:
match tool.get("type"):
case "function":
Expand Down Expand Up @@ -1714,19 +1724,18 @@ async def _process_tool_choice(
else:
# Handle specific tool choice by type
# Each case validates the tool exists in chat_tools before returning
tool_name = getattr(tool_choice, "name", None)
match tool_choice:
case OpenAIResponseInputToolChoiceCustomTool():
if tool_name and tool_name not in chat_tool_names:
logger.warning(f"Tool {tool_name} not found in chat tools")
if tool_choice.name not in chat_tool_names:
logger.warning(f"Tool {tool_choice.name} not found in chat tools")
return None
return OpenAIChatCompletionToolChoiceCustomTool(name=tool_name)
return OpenAIChatCompletionToolChoiceCustomTool(name=tool_choice.name)

case OpenAIResponseInputToolChoiceFunctionTool():
if tool_name and tool_name not in chat_tool_names:
logger.warning(f"Tool {tool_name} not found in chat tools")
if tool_choice.name not in chat_tool_names:
logger.warning(f"Tool {tool_choice.name} not found in chat tools")
return None
return OpenAIChatCompletionToolChoiceFunctionTool(name=tool_name)
return OpenAIChatCompletionToolChoiceFunctionTool(name=tool_choice.name)

case OpenAIResponseInputToolChoiceFileSearch():
if "file_search" not in chat_tool_names:
Expand All @@ -1741,21 +1750,25 @@ async def _process_tool_choice(
return OpenAIChatCompletionToolChoiceFunctionTool(name="web_search")

case OpenAIResponseInputToolChoiceMCPTool():
tool_choice = convert_mcp_tool_choice(
mcp_result = convert_mcp_tool_choice(
chat_tool_names,
tool_choice.server_label,
server_label_to_tools,
tool_name,
tool_choice.name,
)
if isinstance(tool_choice, dict):
if isinstance(mcp_result, dict):
# for single tool choice, return as function tool choice
return OpenAIChatCompletionToolChoiceFunctionTool(name=tool_choice["function"]["name"])
elif isinstance(tool_choice, list):
function_info = mcp_result["function"]
if not isinstance(function_info, dict):
return None
return OpenAIChatCompletionToolChoiceFunctionTool(name=function_info["name"])
elif isinstance(mcp_result, list):
# for multiple tool choices, return as allowed tools
return OpenAIChatCompletionToolChoiceAllowedTools(
tools=tool_choice,
tools=mcp_result,
mode="required",
)
return None


async def resolve_mcp_connector_id(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ async def _execute_tool(
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool

mcp_tool = mcp_tool_to_server[function_name]
if not mcp_tool.server_url:
raise ValueError(f"Failed to invoke MCP tool {function_name}: server_url is not set")
attributes = {
"server_label": mcp_tool.server_label,
"server_url": mcp_tool.server_url,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
def _json_equal(a: str, b: str) -> bool:
"""Compare two JSON strings by value, falling back to string comparison."""
try:
return json.loads(a) == json.loads(b)
# json.loads() returns Any, so == on two Any values is also Any
return cast(bool, json.loads(a) == json.loads(b))
except (json.JSONDecodeError, TypeError):
return a == b

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ def convert_mcp_tool_choice(
server_label: str | None = None,
server_label_to_tools: dict[str, list[str]] | None = None,
tool_name: str | None = None,
) -> dict[str, str] | list[dict[str, str]]:
) -> dict[str, str | dict[str, str]] | list[dict[str, str | dict[str, str]]] | None:
"""Convert a responses tool choice of type mcp to a chat completions compatible function tool choice."""

if tool_name:
Expand All @@ -589,6 +589,8 @@ def convert_mcp_tool_choice(
tool_names = server_label_to_tools.get(server_label, [])
if not tool_names:
return None
matching_tools = [{"type": "function", "function": {"name": tool_name}} for tool_name in tool_names]
matching_tools: list[dict[str, str | dict[str, str]]] = [
{"type": "function", "function": {"name": name}} for name in tool_names
]
return matching_tools
return []
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
OpenAIResponseObjectWithInput,
Order,
ResponseInputItemNotFoundError,
ResponseItemInclude,
ResponseNotFoundError,
)
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType
Expand Down Expand Up @@ -255,7 +256,7 @@ async def list_response_input_items(
response_id: str,
after: str | None = None,
before: str | None = None,
include: list[str] | None = None,
include: list[ResponseItemInclude] | None = None,
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
Expand Down
Loading