160 changes: 158 additions & 2 deletions codex-rs/core/src/chat_completions.rs
@@ -5,6 +5,7 @@ use eventsource_stream::Eventsource;
use futures::Stream;
use futures::StreamExt;
use futures::TryStreamExt;
use regex_lite::Regex;
use reqwest::StatusCode;
use serde_json::json;
use std::pin::Pin;
@@ -36,6 +37,7 @@ pub(crate) async fn stream_chat_completions(
model_family: &ModelFamily,
client: &reqwest::Client,
provider: &ModelProviderInfo,
parallel_tool_calls: bool,
) -> Result<ResponseStream> {
// Build messages array
let mut messages = Vec::<serde_json::Value>::new();
@@ -277,6 +279,13 @@ pub(crate) async fn stream_chat_completions(
"tools": tools_json,
});

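    // Opt into parallel tool calls only when enabled; leaving the field
    // unset preserves each provider's default behavior.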
if parallel_tool_calls && let Some(obj) = payload.as_object_mut() {
obj.insert(
"parallel_tool_calls".to_string(),
serde_json::Value::Bool(true),
);
}

if let Some(schema) = &prompt.output_schema
&& let Some(obj) = payload.as_object_mut()
{
@@ -602,9 +611,52 @@ async fn process_chat_sse<S>(
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
"stop" => {
let mut reasoning_emitted = false;

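    // Some models return tool calls inline as <tool_call> tags in the
    // assistant text instead of the structured `tool_calls` field; recover
    // them here before emitting the final message.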
if !fn_call_state.active && !assistant_text.is_empty() {
let (cleaned_text, parsed_calls) =
extract_embedded_tool_calls(&assistant_text);

if !parsed_calls.is_empty() {
assistant_text = cleaned_text;

if !reasoning_text.is_empty() {
let item = ResponseItem::Reasoning {
id: String::new(),
summary: Vec::new(),
content: Some(vec![ReasoningItemContent::ReasoningText {
text: std::mem::take(&mut reasoning_text),
}]),
encrypted_content: None,
};
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemDone(item)))
.await;
reasoning_emitted = true;
}

for call in parsed_calls {
let call_id = call
.call_id
.unwrap_or_else(|| format!("tool_call_{}", Uuid::new_v4()));
let item = ResponseItem::FunctionCall {
id: None,
name: call.name,
arguments: call.arguments,
call_id,
};
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemDone(item)))
.await;
}
}
}

// Regular turn without a tool call. Emit the final assistant message
// as a single OutputItemDone so non-delta consumers see the result.
if !assistant_text.is_empty() {
let has_message_content =
assistant_text.chars().any(|c| !c.is_whitespace());
if has_message_content {
let item = ResponseItem::Message {
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
@@ -613,9 +665,12 @@
id: None,
};
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
} else {
assistant_text.clear();
}

// Also emit a terminal Reasoning item so UIs can finalize raw reasoning.
- if !reasoning_text.is_empty() {
+ if !reasoning_text.is_empty() && !reasoning_emitted {
let item = ResponseItem::Reasoning {
id: String::new(),
summary: Vec::new(),
@@ -647,6 +702,62 @@
}
}

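/// A tool call recovered from an inline `<tool_call>` tag in assistant text.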
#[derive(Debug)]
struct EmbeddedToolCall {
name: String,
arguments: String,
call_id: Option<String>,
}

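/// Strips `<tool_call>…</tool_call>` blocks from `text` and parses each block's
/// JSON body (`{"name": …, "arguments": …, "id": …}`) into an `EmbeddedToolCall`.
/// Blocks that fail to parse are left in the cleaned text untouched.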
fn extract_embedded_tool_calls(text: &str) -> (String, Vec<EmbeddedToolCall>) {
let regex = match Regex::new(r"<tool_call>\s*([\s\S]*?)\s*</tool_call>") {
Ok(regex) => regex,
Err(_) => return (text.to_string(), Vec::new()),
};
let mut cleaned = String::with_capacity(text.len());
let mut tool_calls = Vec::new();
let mut last_index = 0;

for capture in regex.captures_iter(text) {
if let Some(m) = capture.get(0) {
cleaned.push_str(&text[last_index..m.start()]);

let inner = capture.get(1).map(|c| c.as_str().trim()).unwrap_or("");
match serde_json::from_str::<serde_json::Value>(inner) {
Ok(obj) => {
let name = obj.get("name").and_then(|v| v.as_str());
let arguments_value = obj.get("arguments");
if let (Some(name), Some(arguments_value)) = (name, arguments_value) {
let arguments = if let Some(s) = arguments_value.as_str() {
s.to_string()
} else {
serde_json::to_string(arguments_value).unwrap_or_default()
};
let call_id = obj
.get("id")
.and_then(|v| v.as_str())
.map(std::string::ToString::to_string);
tool_calls.push(EmbeddedToolCall {
name: name.to_string(),
arguments,
call_id,
});
} else {
cleaned.push_str(m.as_str());
}
}
Err(_) => cleaned.push_str(m.as_str()),
}

last_index = m.end();
}
}

cleaned.push_str(&text[last_index..]);

(cleaned, tool_calls)
}

#[cfg(test)]
mod tests {
use super::*;
@@ -693,6 +804,51 @@ mod tests {
"unexpected fallback call_id prefix: {call_id}"
);
}

#[tokio::test]
async fn converts_embedded_tool_call_tags_into_function_calls() {
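        // Two SSE chunks: an assistant content delta carrying an inline
        // <tool_call> tag, followed by a stop event.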
let chunks = vec![
Ok::<Bytes, CodexErr>(Bytes::from_static(
b"data: {\"choices\":[{\"delta\":{\"content\":\"<tool_call>{\\\"name\\\":\\\"run\\\",\\\"arguments\\\":{\\\"cmd\\\":\\\"echo\\\"}}<\\/tool_call>\"}}]}\n\n",
)),
Ok::<Bytes, CodexErr>(Bytes::from_static(
b"data: {\"choices\":[{\"finish_reason\":\"stop\"}]}\n\n",
)),
];

let stream = stream::iter(chunks);
let (tx, mut rx) = mpsc::channel(8);
let handle = tokio::spawn(async move {
process_chat_sse(stream, tx, Duration::from_secs(5)).await;
});

let mut observed = Vec::new();
while let Some(event) = rx.recv().await {
match event.expect("stream event") {
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall {
name,
arguments,
call_id,
..
}) => {
observed.push((name, arguments, call_id));
}
ResponseEvent::OutputItemDone(ResponseItem::Message { .. }) => {
panic!("unexpected assistant message emitted instead of tool call");
}
ResponseEvent::Completed { .. } => break,
_ => {}
}
}

handle.await.expect("process_chat_sse task");

assert_eq!(observed.len(), 1);
let (name, arguments, call_id) = observed.into_iter().next().unwrap();
assert_eq!(name, "run");
assert_eq!(arguments, "{\"cmd\":\"echo\"}");
assert!(call_id.starts_with("tool_call_"));
}
}

/// Optional client-side aggregation helper
3 changes: 2 additions & 1 deletion codex-rs/core/src/client.rs
@@ -127,6 +127,7 @@ impl ModelClient {
&self.config.model_family,
&self.client,
&self.provider,
self.config.parallel_tool_calls,
)
.await?;

@@ -216,7 +217,7 @@ impl ModelClient {
input: &input_with_instructions,
tools: &tools_json,
tool_choice: "auto",
- parallel_tool_calls: false,
+ parallel_tool_calls: self.config.parallel_tool_calls,
reasoning,
store: azure_workaround,
stream: true,
47 changes: 46 additions & 1 deletion codex-rs/core/src/client_common.rs
@@ -21,6 +21,8 @@ use tokio::sync::mpsc;
/// Review thread system prompt. Edit `core/src/review_prompt.md` to customize.
pub const REVIEW_PROMPT: &str = include_str!("../review_prompt.md");

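/// Extra guidance appended for Qwen3-Coder so it emits structured tool calls
/// rather than inline tags or `<think>` blocks.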
const QWEN3_CODER_SYSTEM_INSTRUCTIONS: &str = "When tools are available, call them directly using the `tool_calls` format without extra explanation. Do not output `<think>` blocks.";

/// API request payload for a single model turn
#[derive(Default, Debug, Clone)]
pub struct Prompt {
@@ -53,21 +55,50 @@ impl Prompt {
OpenAiTool::Freeform(f) => f.name == "apply_patch",
_ => false,
});
- if self.base_instructions_override.is_none()
+ let mut instructions = if self.base_instructions_override.is_none()
&& model.needs_special_apply_patch_instructions
&& !is_apply_patch_tool_present
{
Cow::Owned(format!("{base}\n{APPLY_PATCH_TOOL_INSTRUCTIONS}"))
} else {
Cow::Borrowed(base)
};

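        // Append any model-specific guidance on its own line.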
        if let Some(extra) = extra_instructions_for_model(model) {
            let mut owned = instructions.into_owned();
            if !owned.ends_with('\n') {
                owned.push('\n');
            }
            owned.push_str(extra);
            instructions = Cow::Owned(owned);
        }

instructions
}

pub(crate) fn get_formatted_input(&self) -> Vec<ResponseItem> {
self.input.clone()
}
}

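/// Returns extra system instructions to append for specific model slugs.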
fn extra_instructions_for_model(model: &ModelFamily) -> Option<&'static str> {
match model.slug.as_str() {
"qwen/qwen3-coder-30b" => Some(QWEN3_CODER_SYSTEM_INSTRUCTIONS),
_ => None,
}
}

#[derive(Debug)]
pub enum ResponseEvent {
Created,
@@ -368,4 +399,18 @@ mod tests {
let v = serde_json::to_value(&req).expect("json");
assert!(v.get("text").is_none());
}

#[test]
fn adds_qwen3_coder_system_instructions() {
let prompt = Prompt::default();
let model_family = find_family_for_model("qwen/qwen3-coder-30b").expect("known model");

let instructions = prompt.get_full_instructions(&model_family);
let instructions = instructions.as_ref();

assert!(
instructions.ends_with(QWEN3_CODER_SYSTEM_INSTRUCTIONS),
"expected qwen-specific guidance to be appended"
);
}
}
17 changes: 17 additions & 0 deletions codex-rs/core/src/config.rs
@@ -178,6 +178,9 @@ pub struct Config {
/// model family's default preference.
pub include_apply_patch_tool: bool,

/// Enable parallel tool calls when the provider supports them.
pub parallel_tool_calls: bool,

pub tools_web_search_request: bool,

pub use_experimental_streamable_shell_tool: bool,
@@ -694,6 +697,9 @@ pub struct ConfigToml {
/// Optional verbosity control for GPT-5 models (Responses API `text.verbosity`).
pub model_verbosity: Option<Verbosity>,

/// Allow the model to issue multiple tool calls within a single turn.
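/// Set `parallel_tool_calls = true` in `config.toml` to opt in; defaults to `false`.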
pub parallel_tool_calls: Option<bool>,

/// Override to force-enable reasoning summaries for the configured model.
pub model_supports_reasoning_summaries: Option<bool>,

@@ -857,6 +863,7 @@ pub struct ConfigOverrides {
pub include_plan_tool: Option<bool>,
pub include_apply_patch_tool: Option<bool>,
pub include_view_image_tool: Option<bool>,
pub parallel_tool_calls: Option<bool>,
pub show_raw_agent_reasoning: Option<bool>,
pub tools_web_search_request: Option<bool>,
}
@@ -885,6 +892,7 @@ impl Config {
include_plan_tool,
include_apply_patch_tool,
include_view_image_tool,
parallel_tool_calls,
show_raw_agent_reasoning,
tools_web_search_request: override_tools_web_search_request,
} = overrides;
@@ -960,6 +968,10 @@
.or(cfg.tools.as_ref().and_then(|t| t.view_image))
.unwrap_or(true);

let parallel_tool_calls = parallel_tool_calls
.or(cfg.parallel_tool_calls)
.unwrap_or(false);

let model = model
.or(config_profile.model)
.or(cfg.model)
@@ -1052,6 +1064,7 @@
.unwrap_or("https://chatgpt.com/backend-api/".to_string()),
include_plan_tool: include_plan_tool.unwrap_or(false),
include_apply_patch_tool: include_apply_patch_tool.unwrap_or(false),
parallel_tool_calls,
tools_web_search_request,
use_experimental_streamable_shell_tool: cfg
.experimental_use_exec_command_tool
@@ -1801,6 +1814,7 @@ model_verbosity = "high"
base_instructions: None,
include_plan_tool: false,
include_apply_patch_tool: false,
parallel_tool_calls: false,
tools_web_search_request: false,
use_experimental_streamable_shell_tool: false,
use_experimental_unified_exec_tool: false,
@@ -1860,6 +1874,7 @@ model_verbosity = "high"
base_instructions: None,
include_plan_tool: false,
include_apply_patch_tool: false,
parallel_tool_calls: false,
tools_web_search_request: false,
use_experimental_streamable_shell_tool: false,
use_experimental_unified_exec_tool: false,
@@ -1934,6 +1949,7 @@ model_verbosity = "high"
base_instructions: None,
include_plan_tool: false,
include_apply_patch_tool: false,
parallel_tool_calls: false,
tools_web_search_request: false,
use_experimental_streamable_shell_tool: false,
use_experimental_unified_exec_tool: false,
@@ -1994,6 +2010,7 @@ model_verbosity = "high"
base_instructions: None,
include_plan_tool: false,
include_apply_patch_tool: false,
parallel_tool_calls: false,
tools_web_search_request: false,
use_experimental_streamable_shell_tool: false,
use_experimental_unified_exec_tool: false,