diff --git a/rig/rig-core/src/completion/request.rs b/rig/rig-core/src/completion/request.rs index e4cf249ee..c786b5d1d 100644 --- a/rig/rig-core/src/completion/request.rs +++ b/rig/rig-core/src/completion/request.rs @@ -213,6 +213,36 @@ pub struct ToolDefinition { pub parameters: serde_json::Value, } +/// Provider-native tool definition. +/// +/// Stored under `additional_params.tools` and forwarded by providers that support +/// provider-managed tools. +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] +pub struct ProviderToolDefinition { + /// Tool type/kind name as expected by the target provider (for example `web_search`). + #[serde(rename = "type")] + pub kind: String, + /// Additional provider-specific configuration for this hosted tool. + #[serde(flatten, default, skip_serializing_if = "serde_json::Map::is_empty")] + pub config: serde_json::Map, +} + +impl ProviderToolDefinition { + /// Creates a provider-hosted tool definition by type. + pub fn new(kind: impl Into) -> Self { + Self { + kind: kind.into(), + config: serde_json::Map::new(), + } + } + + /// Adds a provider-specific configuration key/value. + pub fn with_config(mut self, key: impl Into, value: serde_json::Value) -> Self { + self.config.insert(key.into(), value); + self + } +} + // ================================================================ // Implementations // ================================================================ @@ -606,6 +636,56 @@ impl CompletionRequest { content: OneOrMany::many(messages).expect("There will be atleast one document"), }) } + + /// Adds a provider-hosted tool by storing it in `additional_params.tools`. + pub fn with_provider_tool(mut self, tool: ProviderToolDefinition) -> Self { + self.additional_params = + merge_provider_tools_into_additional_params(self.additional_params, vec![tool]); + self + } + + /// Adds provider-hosted tools by storing them in `additional_params.tools`. + pub fn with_provider_tools(mut self, tools: Vec) -> Self { + self.additional_params = + merge_provider_tools_into_additional_params(self.additional_params, tools); + self + } +} + +fn merge_provider_tools_into_additional_params( + additional_params: Option, + provider_tools: Vec, +) -> Option { + if provider_tools.is_empty() { + return additional_params; + } + + let mut provider_tools_json = provider_tools + .into_iter() + .map(|ProviderToolDefinition { kind, mut config }| { + // Force the provider tool type from the strongly-typed field. + config.insert("type".to_string(), serde_json::Value::String(kind)); + serde_json::Value::Object(config) + }) + .collect::>(); + + let mut params_map = match additional_params { + Some(serde_json::Value::Object(map)) => map, + Some(serde_json::Value::Bool(stream)) => { + let mut map = serde_json::Map::new(); + map.insert("stream".to_string(), serde_json::Value::Bool(stream)); + map + } + _ => serde_json::Map::new(), + }; + + let mut merged_tools = match params_map.remove("tools") { + Some(serde_json::Value::Array(existing)) => existing, + _ => Vec::new(), + }; + merged_tools.append(&mut provider_tools_json); + params_map.insert("tools".to_string(), serde_json::Value::Array(merged_tools)); + Some(serde_json::Value::Object(params_map)) } /// Builder struct for constructing a completion request. @@ -660,6 +740,7 @@ pub struct CompletionRequestBuilder { chat_history: Vec, documents: Vec, tools: Vec, + provider_tools: Vec, temperature: Option, max_tokens: Option, tool_choice: Option, @@ -677,6 +758,7 @@ impl CompletionRequestBuilder { chat_history: Vec::new(), documents: Vec::new(), tools: Vec::new(), + provider_tools: Vec::new(), temperature: None, max_tokens: None, tool_choice: None, @@ -747,6 +829,19 @@ impl CompletionRequestBuilder { .fold(self, |builder, tool| builder.tool(tool)) } + /// Adds a provider-hosted tool to the completion request. + pub fn provider_tool(mut self, tool: ProviderToolDefinition) -> Self { + self.provider_tools.push(tool); + self + } + + /// Adds provider-hosted tools to the completion request. + pub fn provider_tools(self, tools: Vec) -> Self { + tools + .into_iter() + .fold(self, |builder, tool| builder.provider_tool(tool)) + } + /// Adds additional parameters to the completion request. /// This can be used to set additional provider-specific parameters. For example, /// Cohere's completion models accept a `connectors` parameter that can be used to @@ -831,6 +926,10 @@ impl CompletionRequestBuilder { pub fn build(self) -> CompletionRequest { let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat()) .expect("There will always be atleast the prompt"); + let additional_params = merge_provider_tools_into_additional_params( + self.additional_params, + self.provider_tools, + ); CompletionRequest { model: self.request_model, @@ -841,7 +940,7 @@ impl CompletionRequestBuilder { temperature: self.temperature, max_tokens: self.max_tokens, tool_choice: self.tool_choice, - additional_params: self.additional_params, + additional_params, output_schema: self.output_schema, } } diff --git a/rig/rig-core/src/providers/anthropic/completion.rs b/rig/rig-core/src/providers/anthropic/completion.rs index 4e47e5e90..84d185e13 100644 --- a/rig/rig-core/src/providers/anthropic/completion.rs +++ b/rig/rig-core/src/providers/anthropic/completion.rs @@ -998,7 +998,7 @@ struct AnthropicCompletionRequest { #[serde(skip_serializing_if = "Option::is_none")] tool_choice: Option, #[serde(skip_serializing_if = "Vec::is_empty")] - tools: Vec, + tools: Vec, #[serde(skip_serializing_if = "Option::is_none")] output_config: Option, #[serde(flatten, skip_serializing_if = "Option::is_none")] @@ -1052,7 +1052,7 @@ impl TryFrom> for AnthropicCompletionRequest { fn try_from(params: AnthropicRequestParams<'_>) -> Result { let AnthropicRequestParams { model, - request: req, + request: mut req, prompt_caching, } = params; @@ -1074,7 +1074,14 @@ impl TryFrom> for AnthropicCompletionRequest { .map(Message::try_from) .collect::, _>>()?; - let tools = req + let mut additional_params_payload = req + .additional_params + .take() + .unwrap_or(serde_json::Value::Null); + let mut additional_tools = + extract_tools_from_additional_params(&mut additional_params_payload)?; + + let mut tools = req .tools .into_iter() .map(|tool| ToolDefinition { @@ -1082,7 +1089,9 @@ impl TryFrom> for AnthropicCompletionRequest { description: Some(tool.description), input_schema: tool.parameters, }) - .collect::>(); + .map(serde_json::to_value) + .collect::, _>>()?; + tools.append(&mut additional_tools); // Convert system prompt to array format for cache_control support let mut system = if let Some(preamble) = req.preamble { @@ -1123,11 +1132,31 @@ impl TryFrom> for AnthropicCompletionRequest { tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()), tools, output_config, - additional_params: req.additional_params, + additional_params: if additional_params_payload.is_null() { + None + } else { + Some(additional_params_payload) + }, }) } } +fn extract_tools_from_additional_params( + additional_params: &mut serde_json::Value, +) -> Result, CompletionError> { + if let Some(map) = additional_params.as_object_mut() + && let Some(raw_tools) = map.remove("tools") + { + return serde_json::from_value::>(raw_tools).map_err(|err| { + CompletionError::RequestError( + format!("Invalid Anthropic `additional_params.tools` payload: {err}").into(), + ) + }); + } + + Ok(Vec::new()) +} + impl completion::CompletionModel for CompletionModel where T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static, diff --git a/rig/rig-core/src/providers/anthropic/streaming.rs b/rig/rig-core/src/providers/anthropic/streaming.rs index 6f5c0f0a8..7028dee3d 100644 --- a/rig/rig-core/src/providers/anthropic/streaming.rs +++ b/rig/rig-core/src/providers/anthropic/streaming.rs @@ -1,7 +1,7 @@ use async_stream::stream; use futures::StreamExt; use serde::{Deserialize, Serialize}; -use serde_json::json; +use serde_json::{Value, json}; use tracing::{Level, enabled, info_span}; use tracing_futures::Instrument; @@ -127,7 +127,7 @@ where { pub(crate) async fn stream( &self, - completion_request: CompletionRequest, + mut completion_request: CompletionRequest, ) -> Result, CompletionError> { let request_model = completion_request @@ -209,26 +209,37 @@ where merge_inplace(&mut body, json!({ "temperature": temperature })); } - if !completion_request.tools.is_empty() { + let mut additional_params_payload = completion_request + .additional_params + .take() + .unwrap_or(Value::Null); + let mut additional_tools = + extract_tools_from_additional_params(&mut additional_params_payload)?; + + let mut tools = completion_request + .tools + .into_iter() + .map(|tool| ToolDefinition { + name: tool.name, + description: Some(tool.description), + input_schema: tool.parameters, + }) + .map(serde_json::to_value) + .collect::, _>>()?; + tools.append(&mut additional_tools); + + if !tools.is_empty() { merge_inplace( &mut body, json!({ - "tools": completion_request - .tools - .into_iter() - .map(|tool| ToolDefinition { - name: tool.name, - description: Some(tool.description), - input_schema: tool.parameters, - }) - .collect::>(), + "tools": tools, "tool_choice": ToolChoice::Auto, }), ); } - if let Some(ref params) = completion_request.additional_params { - merge_inplace(&mut body, params.clone()) + if !additional_params_payload.is_null() { + merge_inplace(&mut body, additional_params_payload) } if enabled!(Level::TRACE) { @@ -325,6 +336,22 @@ where } } +fn extract_tools_from_additional_params( + additional_params: &mut Value, +) -> Result, CompletionError> { + if let Some(map) = additional_params.as_object_mut() + && let Some(raw_tools) = map.remove("tools") + { + return serde_json::from_value::>(raw_tools).map_err(|err| { + CompletionError::RequestError( + format!("Invalid Anthropic `additional_params.tools` payload: {err}").into(), + ) + }); + } + + Ok(Vec::new()) +} + fn handle_event( event: &StreamingEvent, current_tool_call: &mut Option, diff --git a/rig/rig-core/src/providers/gemini/completion.rs b/rig/rig-core/src/providers/gemini/completion.rs index 17172f060..1c2960f9d 100644 --- a/rig/rig-core/src/providers/gemini/completion.rs +++ b/rig/rig-core/src/providers/gemini/completion.rs @@ -187,55 +187,71 @@ where pub(crate) fn create_request_body( completion_request: CompletionRequest, ) -> Result { - let mut full_history = Vec::new(); + let documents_message = completion_request.normalized_documents(); + + let CompletionRequest { + model: _, + preamble, + chat_history, + documents: _, + tools: function_tools, + temperature, + max_tokens, + tool_choice, + mut additional_params, + output_schema, + } = completion_request; - // Add documents as a user message at the beginning if present - if let Some(documents_message) = completion_request.normalized_documents() { - full_history.push(documents_message); + let mut full_history = Vec::new(); + if let Some(msg) = documents_message { + full_history.push(msg); } + full_history.extend(chat_history); - full_history.extend(completion_request.chat_history); - - let additional_params = completion_request - .additional_params + let mut additional_params_payload = additional_params + .take() .unwrap_or_else(|| Value::Object(Map::new())); + let mut additional_tools = + extract_tools_from_additional_params(&mut additional_params_payload)?; let AdditionalParameters { mut generation_config, additional_params, - } = serde_json::from_value::(additional_params)?; + } = serde_json::from_value::(additional_params_payload)?; // Apply output_schema to generation_config, creating one if needed - if let Some(schema) = completion_request.output_schema { + if let Some(schema) = output_schema { let cfg = generation_config.get_or_insert_with(GenerationConfig::default); cfg.response_mime_type = Some("application/json".to_string()); cfg.response_json_schema = Some(schema.to_value()); } generation_config = generation_config.map(|mut cfg| { - if let Some(temp) = completion_request.temperature { + if let Some(temp) = temperature { cfg.temperature = Some(temp); }; - if let Some(max_tokens) = completion_request.max_tokens { + if let Some(max_tokens) = max_tokens { cfg.max_output_tokens = Some(max_tokens); }; cfg }); - let system_instruction = completion_request.preamble.clone().map(|preamble| Content { + let system_instruction = preamble.clone().map(|preamble| Content { parts: vec![preamble.into()], role: Some(Role::Model), }); - let tools = if completion_request.tools.is_empty() { - None + let mut tools = if function_tools.is_empty() { + Vec::new() } else { - Some(vec![Tool::try_from(completion_request.tools)?]) + vec![serde_json::to_value(Tool::try_from(function_tools)?)?] }; + tools.append(&mut additional_tools); + let tools = if tools.is_empty() { None } else { Some(tools) }; - let tool_config = if let Some(cfg) = completion_request.tool_choice { + let tool_config = if let Some(cfg) = tool_choice { Some(ToolConfig { function_calling_config: Some(FunctionCallingMode::try_from(cfg)?), }) @@ -262,6 +278,22 @@ pub(crate) fn create_request_body( Ok(request) } +fn extract_tools_from_additional_params( + additional_params: &mut Value, +) -> Result, CompletionError> { + if let Some(map) = additional_params.as_object_mut() + && let Some(raw_tools) = map.remove("tools") + { + return serde_json::from_value::>(raw_tools).map_err(|err| { + CompletionError::RequestError( + format!("Invalid Gemini `additional_params.tools` payload: {err}").into(), + ) + }); + } + + Ok(Vec::new()) +} + pub(crate) fn resolve_request_model( default_model: &str, completion_request: &CompletionRequest, @@ -1815,7 +1847,7 @@ pub mod gemini_api_types { pub struct GenerateContentRequest { pub contents: Vec, #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option>, + pub tools: Option>, pub tool_config: Option, /// Optional. Configuration options for model generation and outputs. pub generation_config: Option, diff --git a/rig/rig-core/src/providers/groq.rs b/rig/rig-core/src/providers/groq.rs index a36795cb3..11a667c19 100644 --- a/rig/rig-core/src/providers/groq.rs +++ b/rig/rig-core/src/providers/groq.rs @@ -10,7 +10,7 @@ //! ``` use bytes::Bytes; use http::Request; -use serde_json::Map; +use serde_json::{Map, Value}; use std::collections::HashMap; use tracing::info_span; use tracing_futures::Instrument; @@ -184,7 +184,7 @@ pub(super) struct StreamOptions { impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest { type Error = CompletionError; - fn try_from((model, req): (&str, CompletionRequest)) -> Result { + fn try_from((model, mut req): (&str, CompletionRequest)) -> Result { if req.output_schema.is_some() { tracing::warn!("Structured outputs currently not supported for Groq"); } @@ -219,12 +219,17 @@ impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest { .map(crate::providers::openai::ToolChoice::try_from) .transpose()?; - let additional_params: Option = - if let Some(params) = req.additional_params { - Some(serde_json::from_value(params)?) - } else { + let mut additional_params_payload = req.additional_params.take().unwrap_or(Value::Null); + let native_tools = + extract_native_tools_from_additional_params(&mut additional_params_payload)?; + + let mut additional_params: Option = + if additional_params_payload.is_null() { None + } else { + Some(serde_json::from_value(additional_params_payload)?) }; + apply_native_tools_to_additional_params(&mut additional_params, native_tools); Ok(Self { model: model.to_string(), @@ -244,6 +249,75 @@ impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest { } } +fn extract_native_tools_from_additional_params( + additional_params: &mut Value, +) -> Result, CompletionError> { + if let Some(map) = additional_params.as_object_mut() + && let Some(raw_tools) = map.remove("tools") + { + return serde_json::from_value::>(raw_tools).map_err(|err| { + CompletionError::RequestError( + format!("Invalid Groq `additional_params.tools` payload: {err}").into(), + ) + }); + } + + Ok(Vec::new()) +} + +fn apply_native_tools_to_additional_params( + additional_params: &mut Option, + native_tools: Vec, +) { + if native_tools.is_empty() { + return; + } + + let params = additional_params.get_or_insert_with(GroqAdditionalParameters::default); + let extra = params.extra.get_or_insert_with(Map::new); + + let mut compound_custom = match extra.remove("compound_custom") { + Some(Value::Object(map)) => map, + _ => Map::new(), + }; + + let mut enabled_tools = match compound_custom.remove("enabled_tools") { + Some(Value::Array(values)) => values, + _ => Vec::new(), + }; + + for native_tool in native_tools { + let already_enabled = enabled_tools + .iter() + .any(|existing| native_tools_match(existing, &native_tool)); + if !already_enabled { + enabled_tools.push(native_tool); + } + } + + compound_custom.insert("enabled_tools".to_string(), Value::Array(enabled_tools)); + extra.insert( + "compound_custom".to_string(), + Value::Object(compound_custom), + ); +} + +fn native_tools_match(lhs: &Value, rhs: &Value) -> bool { + if let (Some(lhs_type), Some(rhs_type)) = (native_tool_kind(lhs), native_tool_kind(rhs)) { + return lhs_type == rhs_type; + } + + lhs == rhs +} + +fn native_tool_kind(value: &Value) -> Option<&str> { + match value { + Value::String(kind) => Some(kind), + Value::Object(map) => map.get("type").and_then(Value::as_str), + _ => None, + } +} + /// Additional parameters to send to the Groq API #[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct GroqAdditionalParameters { diff --git a/rig/rig-core/src/providers/openai/responses_api/mod.rs b/rig/rig-core/src/providers/openai/responses_api/mod.rs index 2c2996987..cbea84791 100644 --- a/rig/rig-core/src/providers/openai/responses_api/mod.rs +++ b/rig/rig-core/src/providers/openai/responses_api/mod.rs @@ -56,7 +56,8 @@ pub struct CompletionRequest { /// If none provided, the default option is "auto". #[serde(skip_serializing_if = "Option::is_none")] tool_choice: Option, - /// The tools you want to use. Currently this is limited to functions, but will be expanded on in future. + /// The tools you want to use. This supports both function tools and hosted tools + /// such as `web_search`, `file_search`, and `computer_use`. #[serde(skip_serializing_if = "Vec::is_empty")] pub tools: Vec, /// Additional parameters @@ -79,6 +80,25 @@ impl CompletionRequest { self } + + /// Adds a provider-native hosted tool (e.g. `web_search`, `file_search`, `computer_use`) + /// to the request. These tools are executed by OpenAI's infrastructure, not by Rig's + /// agent loop. + pub fn with_tool(mut self, tool: impl Into) -> Self { + self.tools.push(tool.into()); + self + } + + /// Adds multiple provider-native hosted tools to the request. These tools are executed + /// by OpenAI's infrastructure, not by Rig's agent loop. + pub fn with_tools(mut self, tools: I) -> Self + where + I: IntoIterator, + Tool: Into, + { + self.tools.extend(tools.into_iter().map(Into::into)); + self + } } /// An input item for [`CompletionRequest`]. @@ -480,38 +500,106 @@ fn openai_reasoning_from_core( } /// The definition of a tool response, repurposed for OpenAI's Responses API. -#[derive(Debug, Deserialize, Serialize, Clone)] +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] pub struct ResponsesToolDefinition { + /// The type of tool. + #[serde(rename = "type")] + pub kind: String, /// Tool name + #[serde(default, skip_serializing_if = "String::is_empty")] pub name: String, /// Parameters - this should be a JSON schema. Tools should additionally ensure an "additionalParameters" field has been added with the value set to false, as this is required if using OpenAI's strict mode (enabled by default). + #[serde(default, skip_serializing_if = "is_json_null")] pub parameters: serde_json::Value, /// Whether to use strict mode. Enabled by default as it allows for improved efficiency. + #[serde(default, skip_serializing_if = "is_false")] pub strict: bool, - /// The type of tool. This should always be "function". - #[serde(rename = "type")] - pub kind: String, /// Tool description. + #[serde(default, skip_serializing_if = "String::is_empty")] pub description: String, + /// Additional provider-specific configuration for hosted tools. + #[serde(flatten, default, skip_serializing_if = "Map::is_empty")] + pub config: Map, } -impl From for ResponsesToolDefinition { - fn from(value: completion::ToolDefinition) -> Self { - let completion::ToolDefinition { - name, - mut parameters, - description, - } = value; +fn is_json_null(value: &Value) -> bool { + value.is_null() +} +fn is_false(value: &bool) -> bool { + !value +} + +impl ResponsesToolDefinition { + /// Creates a function tool definition. + pub fn function( + name: impl Into, + description: impl Into, + mut parameters: serde_json::Value, + ) -> Self { super::sanitize_schema(&mut parameters); Self { - name, - parameters, - description, kind: "function".to_string(), + name: name.into(), + parameters, strict: true, + description: description.into(), + config: Map::new(), + } + } + + /// Creates a hosted tool definition for an arbitrary hosted tool type. + pub fn hosted(kind: impl Into) -> Self { + Self { + kind: kind.into(), + name: String::new(), + parameters: Value::Null, + strict: false, + description: String::new(), + config: Map::new(), + } + } + + /// Creates a hosted `web_search` tool definition. + pub fn web_search() -> Self { + Self::hosted("web_search") + } + + /// Creates a hosted `file_search` tool definition. + pub fn file_search() -> Self { + Self::hosted("file_search") + } + + /// Creates a hosted `computer_use` tool definition. + pub fn computer_use() -> Self { + Self::hosted("computer_use") + } + + /// Adds hosted-tool configuration fields. + pub fn with_config(mut self, key: impl Into, value: Value) -> Self { + self.config.insert(key.into(), value); + self + } + + fn normalize(mut self) -> Self { + if self.kind == "function" { + super::sanitize_schema(&mut self.parameters); + self.strict = true; } + self + } +} + +impl From for ResponsesToolDefinition { + fn from(value: completion::ToolDefinition) -> Self { + let completion::ToolDefinition { + name, + parameters, + description, + } = value; + + Self::function(name, description, parameters) } } @@ -655,7 +743,7 @@ pub enum ResponseStatus { impl TryFrom<(String, crate::completion::CompletionRequest)> for CompletionRequest { type Error = CompletionError; fn try_from( - (model, req): (String, crate::completion::CompletionRequest), + (model, mut req): (String, crate::completion::CompletionRequest), ) -> Result { let model = req.model.clone().unwrap_or(model); let input = { @@ -687,21 +775,51 @@ impl TryFrom<(String, crate::completion::CompletionRequest)> for CompletionReque ) })?; - let stream = req - .additional_params - .clone() - .unwrap_or(Value::Null) - .as_bool(); + let mut additional_params_payload = req.additional_params.take().unwrap_or(Value::Null); + let stream = match &additional_params_payload { + Value::Bool(stream) => Some(*stream), + Value::Object(map) => map.get("stream").and_then(Value::as_bool), + _ => None, + }; - let mut additional_parameters = if let Some(map) = req.additional_params { - serde_json::from_value::(map).map_err(|err| { - CompletionError::RequestError( - format!("Invalid OpenAI Responses additional_params payload: {err}").into(), + let mut additional_tools = Vec::new(); + if let Some(additional_params_map) = additional_params_payload.as_object_mut() { + if let Some(raw_tools) = additional_params_map.remove("tools") { + additional_tools = serde_json::from_value::>( + raw_tools, ) - })? - } else { + .map_err(|err| { + CompletionError::RequestError( + format!( + "Invalid OpenAI Responses tools payload in additional_params: {err}" + ) + .into(), + ) + })?; + } + additional_params_map.remove("stream"); + } + + if additional_params_payload.is_boolean() { + additional_params_payload = Value::Null; + } + + additional_tools = additional_tools + .into_iter() + .map(ResponsesToolDefinition::normalize) + .collect(); + + let mut additional_parameters = if additional_params_payload.is_null() { // If there's no additional parameters, initialise an empty object AdditionalParameters::default() + } else { + serde_json::from_value::(additional_params_payload).map_err( + |err| { + CompletionError::RequestError( + format!("Invalid OpenAI Responses additional_params payload: {err}").into(), + ) + }, + )? }; if additional_parameters.reasoning.is_some() { let include = additional_parameters.include.get_or_insert_with(Vec::new); @@ -729,6 +847,12 @@ impl TryFrom<(String, crate::completion::CompletionRequest)> for CompletionReque } let tool_choice = req.tool_choice.map(ToolChoice::try_from).transpose()?; + let mut tools: Vec = req + .tools + .into_iter() + .map(ResponsesToolDefinition::from) + .collect(); + tools.append(&mut additional_tools); Ok(Self { input, @@ -737,11 +861,7 @@ impl TryFrom<(String, crate::completion::CompletionRequest)> for CompletionReque max_output_tokens: req.max_tokens, stream, tool_choice, - tools: req - .tools - .into_iter() - .map(ResponsesToolDefinition::from) - .collect(), + tools, temperature: req.temperature, additional_parameters, }) @@ -755,6 +875,8 @@ pub struct ResponsesCompletionModel { pub(crate) client: Client, /// Name of the model (e.g.: gpt-3.5-turbo-1106) pub model: String, + /// Model-level default tools that are always added to outgoing requests. + pub tools: Vec, } impl ResponsesCompletionModel @@ -766,6 +888,7 @@ where Self { client, model: model.into(), + tools: Vec::new(), } } @@ -773,9 +896,26 @@ where Self { client, model: model.to_string(), + tools: Vec::new(), } } + /// Adds a default tool to all requests from this model. + pub fn with_tool(mut self, tool: impl Into) -> Self { + self.tools.push(tool.into()); + self + } + + /// Adds default tools to all requests from this model. + pub fn with_tools(mut self, tools: I) -> Self + where + I: IntoIterator, + Tool: Into, + { + self.tools.extend(tools.into_iter().map(Into::into)); + self + } + /// Use the Completions API instead of Responses. pub fn completions_api(self) -> crate::providers::openai::completion::CompletionModel { super::completion::CompletionModel::with_model(self.client.completions_api(), &self.model) @@ -786,7 +926,8 @@ where &self, completion_request: crate::completion::CompletionRequest, ) -> Result { - let req = CompletionRequest::try_from((self.model.clone(), completion_request))?; + let mut req = CompletionRequest::try_from((self.model.clone(), completion_request))?; + req.tools.extend(self.tools.clone()); Ok(req) } diff --git a/rig/rig-core/src/providers/xai/completion.rs b/rig/rig-core/src/providers/xai/completion.rs index 1eabe6422..3e1537e53 100644 --- a/rig/rig-core/src/providers/xai/completion.rs +++ b/rig/rig-core/src/providers/xai/completion.rs @@ -4,6 +4,7 @@ use bytes::Bytes; use serde::{Deserialize, Serialize}; +use serde_json::Value; use tracing::{Instrument, Level, enabled, info_span}; use super::api::{ApiResponse, Message, ToolDefinition}; @@ -39,7 +40,7 @@ pub(super) struct XAICompletionRequest { #[serde(skip_serializing_if = "Option::is_none")] max_output_tokens: Option, #[serde(skip_serializing_if = "Vec::is_empty")] - tools: Vec, + tools: Vec, #[serde(skip_serializing_if = "Option::is_none")] tool_choice: Option, #[serde(flatten, skip_serializing_if = "Option::is_none")] @@ -54,6 +55,7 @@ impl TryFrom<(&str, CompletionRequest)> for XAICompletionRequest { tracing::warn!("Structured outputs currently not supported for xAI"); } let model = req.model.clone().unwrap_or_else(|| model.to_string()); + let mut additional_params_payload = req.additional_params.unwrap_or(Value::Null); let mut input: Vec = req .preamble .as_ref() @@ -65,7 +67,20 @@ impl TryFrom<(&str, CompletionRequest)> for XAICompletionRequest { } let tool_choice = req.tool_choice.map(ToolChoice::try_from).transpose()?; - let tools = req.tools.into_iter().map(ToolDefinition::from).collect(); + let mut additional_tools = + extract_tools_from_additional_params(&mut additional_params_payload)?; + let mut tools = req + .tools + .into_iter() + .map(ToolDefinition::from) + .map(serde_json::to_value) + .collect::, _>>()?; + tools.append(&mut additional_tools); + let additional_params = if additional_params_payload.is_null() { + None + } else { + Some(additional_params_payload) + }; Ok(Self { model: model.to_string(), @@ -74,11 +89,27 @@ impl TryFrom<(&str, CompletionRequest)> for XAICompletionRequest { max_output_tokens: req.max_tokens, tools, tool_choice, - additional_params: req.additional_params, + additional_params, }) } } +fn extract_tools_from_additional_params( + additional_params: &mut Value, +) -> Result, CompletionError> { + if let Some(map) = additional_params.as_object_mut() + && let Some(raw_tools) = map.remove("tools") + { + return serde_json::from_value::>(raw_tools).map_err(|err| { + CompletionError::RequestError( + format!("Invalid xAI `additional_params.tools` payload: {err}").into(), + ) + }); + } + + Ok(Vec::new()) +} + // ================================================================ // Response Types // ================================================================ diff --git a/rig/rig-core/tests/openai_responses_input_item.rs b/rig/rig-core/tests/openai_responses_input_item.rs index 90c8b3343..8bf58f79d 100644 --- a/rig/rig-core/tests/openai_responses_input_item.rs +++ b/rig/rig-core/tests/openai_responses_input_item.rs @@ -342,7 +342,7 @@ fn openai_responses_invalid_additional_params_returns_error_without_panicking() temperature: None, max_tokens: None, tool_choice: None, - additional_params: Some(serde_json::json!(true)), + additional_params: Some(serde_json::json!("not_a_valid_object")), model: None, output_schema: None, };