From 01b014dd5e95dced9d15b728faedea17ef8d73ee Mon Sep 17 00:00:00 2001 From: rislam Date: Tue, 3 Mar 2026 23:46:56 -0800 Subject: [PATCH] Normalize tool_calls and gate parser tool-calls to tool-enabled requests Convert parser tool_calls to JSON-safe dicts and suppress incidental tool-call parsing when tools are not requested, preventing client serialization failures and unintended tool_calls/finish_reason behavior. --- .../endpoints/chat_completions.py | 51 +++++++++++++++++-- 1 file changed, 47 insertions(+), 4 deletions(-) diff --git a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py index a4cb61fb962..3e427bcaf03 100644 --- a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py +++ b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py @@ -1,6 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. import asyncio +import json import logging import time import traceback @@ -12,6 +13,38 @@ logger = logging.getLogger(__name__) + +def _get_field(obj, key, default=None): + """Read a field from dict-like or object-like values.""" + if isinstance(obj, dict): + return obj.get(key, default) + return getattr(obj, key, default) + + +def _normalize_tool_calls(tool_calls): + """Normalize tool calls to OpenAI-compatible JSON primitives.""" + normalized = [] + for call in tool_calls or []: + fn = _get_field(call, "function", {}) or {} + fn_name = _get_field(fn, "name") + fn_args = _get_field(fn, "arguments", "") + if fn_name is None: + continue + if not isinstance(fn_args, str): + try: + fn_args = json.dumps(fn_args, ensure_ascii=False) + except TypeError: + fn_args = str(fn_args) + normalized.append( + { + "id": str(_get_field(call, "id", f"call_{uuid.uuid4().hex[:24]}")), + "type": "function", + "function": {"name": str(fn_name), "arguments": fn_args}, + } + ) + return normalized + + try: from flask import Blueprint, current_app, jsonify, request @@ -26,6 +59,8 @@ async def chat_completions(): parsers = current_app.config['parsers'] req = request.get_json() + tools = req.get("tools", None) + tools_requested = bool(tools) # --- 1. Parse Messages --- messages = req.get("messages") @@ -36,7 +71,7 @@ async def chat_completions(): try: prompt_tokens = tokenizer.apply_chat_template( - messages, tokenize=True, add_generation_prompt=True, tools=req.get("tools", None) + messages, tokenize=True, add_generation_prompt=True, tools=tools ) except (AttributeError, AssertionError): warnings.warn( @@ -168,15 +203,23 @@ async def chat_completions(): for parser in parsers: if parser not in PARSER_MAPPING: raise ValueError(f"Parser {parser} not found in PARSER_MAPPING") - message_text, new_info = PARSER_MAPPING[parser].parse( - message_text, tools=req.get("tools", None) + prev_text = message_text + parsed_text, new_info = PARSER_MAPPING[parser].parse( + message_text, tools=tools ) + if "tool_calls" in new_info: + new_info["tool_calls"] = _normalize_tool_calls(new_info.get("tool_calls", [])) + if not tools_requested: + # Ignore incidental tool-call syntax in plain chat mode. + parsed_text = prev_text + new_info.pop("tool_calls", None) + message_text = parsed_text assert not ( metadata.keys() & new_info.keys() ), "Multiple parsers found the same information." metadata.update(new_info) message = {"role": "assistant", "content": message_text} - if "tool_calls" in metadata: + if metadata.get("tool_calls", []): message["tool_calls"] = metadata["tool_calls"] if "reasoning" in metadata: message["reasoning"] = metadata["reasoning"]