Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 100 additions & 1 deletion rig/rig-core/src/completion/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,36 @@ pub struct ToolDefinition {
pub parameters: serde_json::Value,
}

/// Provider-native tool definition.
///
/// Stored under `additional_params.tools` and forwarded by providers that support
/// provider-managed tools.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ProviderToolDefinition {
/// Tool type/kind name as expected by the target provider (for example `web_search`).
#[serde(rename = "type")]
pub kind: String,
/// Additional provider-specific configuration for this hosted tool.
#[serde(flatten, default, skip_serializing_if = "serde_json::Map::is_empty")]
pub config: serde_json::Map<String, serde_json::Value>,
}

impl ProviderToolDefinition {
/// Creates a provider-hosted tool definition by type.
pub fn new(kind: impl Into<String>) -> Self {
Self {
kind: kind.into(),
config: serde_json::Map::new(),
}
}

/// Adds a provider-specific configuration key/value.
pub fn with_config(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
self.config.insert(key.into(), value);
self
}
}

// ================================================================
// Implementations
// ================================================================
Expand Down Expand Up @@ -606,6 +636,56 @@ impl CompletionRequest {
content: OneOrMany::many(messages).expect("There will be atleast one document"),
})
}

/// Adds a provider-hosted tool by storing it in `additional_params.tools`.
pub fn with_provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
self.additional_params =
merge_provider_tools_into_additional_params(self.additional_params, vec![tool]);
self
}

/// Adds provider-hosted tools by storing them in `additional_params.tools`.
pub fn with_provider_tools(mut self, tools: Vec<ProviderToolDefinition>) -> Self {
self.additional_params =
merge_provider_tools_into_additional_params(self.additional_params, tools);
self
}
}

fn merge_provider_tools_into_additional_params(
additional_params: Option<serde_json::Value>,
provider_tools: Vec<ProviderToolDefinition>,
) -> Option<serde_json::Value> {
if provider_tools.is_empty() {
return additional_params;
}

let mut provider_tools_json = provider_tools
.into_iter()
.map(|ProviderToolDefinition { kind, mut config }| {
// Force the provider tool type from the strongly-typed field.
config.insert("type".to_string(), serde_json::Value::String(kind));
serde_json::Value::Object(config)
})
.collect::<Vec<_>>();

let mut params_map = match additional_params {
Some(serde_json::Value::Object(map)) => map,
Some(serde_json::Value::Bool(stream)) => {
let mut map = serde_json::Map::new();
map.insert("stream".to_string(), serde_json::Value::Bool(stream));
map
}
_ => serde_json::Map::new(),
};

let mut merged_tools = match params_map.remove("tools") {
Some(serde_json::Value::Array(existing)) => existing,
_ => Vec::new(),
};
merged_tools.append(&mut provider_tools_json);
params_map.insert("tools".to_string(), serde_json::Value::Array(merged_tools));
Some(serde_json::Value::Object(params_map))
}

/// Builder struct for constructing a completion request.
Expand Down Expand Up @@ -660,6 +740,7 @@ pub struct CompletionRequestBuilder<M: CompletionModel> {
chat_history: Vec<Message>,
documents: Vec<Document>,
tools: Vec<ToolDefinition>,
provider_tools: Vec<ProviderToolDefinition>,
temperature: Option<f64>,
max_tokens: Option<u64>,
tool_choice: Option<ToolChoice>,
Expand All @@ -677,6 +758,7 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
chat_history: Vec::new(),
documents: Vec::new(),
tools: Vec::new(),
provider_tools: Vec::new(),
temperature: None,
max_tokens: None,
tool_choice: None,
Expand Down Expand Up @@ -747,6 +829,19 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
.fold(self, |builder, tool| builder.tool(tool))
}

/// Adds a provider-hosted tool to the completion request.
pub fn provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
self.provider_tools.push(tool);
self
}

/// Adds provider-hosted tools to the completion request.
pub fn provider_tools(self, tools: Vec<ProviderToolDefinition>) -> Self {
tools
.into_iter()
.fold(self, |builder, tool| builder.provider_tool(tool))
}

/// Adds additional parameters to the completion request.
/// This can be used to set additional provider-specific parameters. For example,
/// Cohere's completion models accept a `connectors` parameter that can be used to
Expand Down Expand Up @@ -831,6 +926,10 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
pub fn build(self) -> CompletionRequest {
let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat())
.expect("There will always be atleast the prompt");
let additional_params = merge_provider_tools_into_additional_params(
self.additional_params,
self.provider_tools,
);

CompletionRequest {
model: self.request_model,
Expand All @@ -841,7 +940,7 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
temperature: self.temperature,
max_tokens: self.max_tokens,
tool_choice: self.tool_choice,
additional_params: self.additional_params,
additional_params,
output_schema: self.output_schema,
}
}
Expand Down
39 changes: 34 additions & 5 deletions rig/rig-core/src/providers/anthropic/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -998,7 +998,7 @@ struct AnthropicCompletionRequest {
#[serde(skip_serializing_if = "Option::is_none")]
tool_choice: Option<ToolChoice>,
#[serde(skip_serializing_if = "Vec::is_empty")]
tools: Vec<ToolDefinition>,
tools: Vec<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
output_config: Option<OutputConfig>,
#[serde(flatten, skip_serializing_if = "Option::is_none")]
Expand Down Expand Up @@ -1052,7 +1052,7 @@ impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
fn try_from(params: AnthropicRequestParams<'_>) -> Result<Self, Self::Error> {
let AnthropicRequestParams {
model,
request: req,
request: mut req,
prompt_caching,
} = params;

Expand All @@ -1074,15 +1074,24 @@ impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
.map(Message::try_from)
.collect::<Result<Vec<Message>, _>>()?;

let tools = req
let mut additional_params_payload = req
.additional_params
.take()
.unwrap_or(serde_json::Value::Null);
let mut additional_tools =
extract_tools_from_additional_params(&mut additional_params_payload)?;

let mut tools = req
.tools
.into_iter()
.map(|tool| ToolDefinition {
name: tool.name,
description: Some(tool.description),
input_schema: tool.parameters,
})
.collect::<Vec<_>>();
.map(serde_json::to_value)
.collect::<Result<Vec<_>, _>>()?;
tools.append(&mut additional_tools);

// Convert system prompt to array format for cache_control support
let mut system = if let Some(preamble) = req.preamble {
Expand Down Expand Up @@ -1123,11 +1132,31 @@ impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
tools,
output_config,
additional_params: req.additional_params,
additional_params: if additional_params_payload.is_null() {
None
} else {
Some(additional_params_payload)
},
})
}
}

fn extract_tools_from_additional_params(
additional_params: &mut serde_json::Value,
) -> Result<Vec<serde_json::Value>, CompletionError> {
if let Some(map) = additional_params.as_object_mut()
&& let Some(raw_tools) = map.remove("tools")
{
return serde_json::from_value::<Vec<serde_json::Value>>(raw_tools).map_err(|err| {
CompletionError::RequestError(
format!("Invalid Anthropic `additional_params.tools` payload: {err}").into(),
)
});
}

Ok(Vec::new())
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
Expand Down
55 changes: 41 additions & 14 deletions rig/rig-core/src/providers/anthropic/streaming.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use async_stream::stream;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::{Value, json};
use tracing::{Level, enabled, info_span};
use tracing_futures::Instrument;

Expand Down Expand Up @@ -127,7 +127,7 @@ where
{
pub(crate) async fn stream(
&self,
completion_request: CompletionRequest,
mut completion_request: CompletionRequest,
) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
{
let request_model = completion_request
Expand Down Expand Up @@ -209,26 +209,37 @@ where
merge_inplace(&mut body, json!({ "temperature": temperature }));
}

if !completion_request.tools.is_empty() {
let mut additional_params_payload = completion_request
.additional_params
.take()
.unwrap_or(Value::Null);
let mut additional_tools =
extract_tools_from_additional_params(&mut additional_params_payload)?;

let mut tools = completion_request
.tools
.into_iter()
.map(|tool| ToolDefinition {
name: tool.name,
description: Some(tool.description),
input_schema: tool.parameters,
})
.map(serde_json::to_value)
.collect::<Result<Vec<_>, _>>()?;
tools.append(&mut additional_tools);

if !tools.is_empty() {
merge_inplace(
&mut body,
json!({
"tools": completion_request
.tools
.into_iter()
.map(|tool| ToolDefinition {
name: tool.name,
description: Some(tool.description),
input_schema: tool.parameters,
})
.collect::<Vec<_>>(),
"tools": tools,
"tool_choice": ToolChoice::Auto,
}),
);
}

if let Some(ref params) = completion_request.additional_params {
merge_inplace(&mut body, params.clone())
if !additional_params_payload.is_null() {
merge_inplace(&mut body, additional_params_payload)
}

if enabled!(Level::TRACE) {
Expand Down Expand Up @@ -325,6 +336,22 @@ where
}
}

fn extract_tools_from_additional_params(
additional_params: &mut Value,
) -> Result<Vec<Value>, CompletionError> {
if let Some(map) = additional_params.as_object_mut()
&& let Some(raw_tools) = map.remove("tools")
{
return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
CompletionError::RequestError(
format!("Invalid Anthropic `additional_params.tools` payload: {err}").into(),
)
});
}

Ok(Vec::new())
}

fn handle_event(
event: &StreamingEvent,
current_tool_call: &mut Option<ToolCallState>,
Expand Down
Loading
Loading