From 2502402bef36a9e61c642883f8c92238413da943 Mon Sep 17 00:00:00 2001 From: aniketfuryrocks Date: Mon, 5 Jan 2026 13:04:37 +0530 Subject: [PATCH 1/2] feat: implement cheap clone via Arc for all LLM backends This change wraps backend configurations in Arc to enable O(1) cloning instead of expensive deep copies. All backend structs now derive Clone. Changes: - Wrap config fields in Arc<*Config> for all backends - Make config structs and fields public for direct access - Make client field public on all backends - Add with_client() constructors for HTTP client injection/pooling - Add doc comments to all config struct fields - Remove redundant accessor methods (use direct field access instead) Backends updated: - Anthropic, Google, Ollama, DeepSeek, XAI, AzureOpenAI, Phind, ElevenLabs - OpenAICompatibleProvider (used by OpenAI, Groq, Mistral, Cohere, OpenRouter, HuggingFace) Breaking changes: - Previously private config fields are now accessed via .config.field - client() getter removed, use .client directly Draft assisted with OpenCode. --- src/backends/anthropic.rs | 214 ++++++++++++++++++++++------- src/backends/azure_openai.rs | 166 ++++++++++++++++------ src/backends/cohere.rs | 14 +- src/backends/deepseek.rs | 73 ++++++++-- src/backends/elevenlabs.rs | 89 +++++++----- src/backends/google.rs | 176 +++++++++++++++--------- src/backends/groq.rs | 6 +- src/backends/huggingface.rs | 6 +- src/backends/mistral.rs | 18 +-- src/backends/ollama.rs | 120 +++++++++++----- src/backends/openai.rs | 114 +++++++-------- src/backends/openrouter.rs | 6 +- src/backends/phind.rs | 80 ++++++++--- src/backends/xai.rs | 199 ++++++++++++++++----------- src/providers/openai_compatible.rs | 187 ++++++++++++++++++------- 15 files changed, 992 insertions(+), 476 deletions(-) diff --git a/src/backends/anthropic.rs b/src/backends/anthropic.rs index 89fb600..80b704d 100644 --- a/src/backends/anthropic.rs +++ b/src/backends/anthropic.rs @@ -3,6 +3,7 @@ //! 
This module provides integration with Anthropic's Claude models through their API. use std::collections::HashMap; +use std::sync::Arc; use crate::{ builder::LLMBackend, @@ -26,24 +27,51 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::Value; -/// Client for interacting with Anthropic's API. +/// Configuration for the Anthropic client. /// -/// Provides methods for chat and completion requests using Anthropic's models. +/// This struct holds all configuration options and is wrapped in an `Arc` +/// to enable cheap cloning of the `Anthropic` client. #[derive(Debug)] -pub struct Anthropic { +pub struct AnthropicConfig { + /// API key for authentication with Anthropic. pub api_key: String, + /// Model identifier (e.g., "claude-3-sonnet-20240229"). pub model: String, + /// Maximum tokens to generate in responses. pub max_tokens: u32, + /// Sampling temperature for response randomness. pub temperature: f32, + /// Request timeout in seconds. pub timeout_seconds: u64, + /// System prompt to guide model behavior. pub system: String, + /// Top-p (nucleus) sampling parameter. pub top_p: Option, + /// Top-k sampling parameter. pub top_k: Option, + /// Available tools for the model to use. pub tools: Option>, + /// Tool choice configuration. pub tool_choice: Option, + /// Whether extended thinking is enabled. pub reasoning: bool, + /// Budget tokens for extended thinking. pub thinking_budget_tokens: Option, - client: Client, +} + +/// Client for interacting with Anthropic's API. +/// +/// Provides methods for chat and completion requests using Anthropic's models. +/// +/// The client uses `Arc` internally for configuration, making cloning cheap +/// (only an atomic reference count increment). This allows sharing a single +/// client across multiple tasks without expensive deep copies. +#[derive(Debug, Clone)] +pub struct Anthropic { + /// Shared configuration wrapped in Arc for cheap cloning. + pub config: Arc, + /// HTTP client for making requests. 
+ pub client: Client, } /// Anthropic-specific tool format that matches their API structure @@ -469,7 +497,6 @@ impl Anthropic { /// * `temperature` - Sampling temperature (defaults to 0.7) /// * `timeout_seconds` - Request timeout in seconds (defaults to 30) /// * `system` - System prompt (defaults to "You are a helpful assistant.") - /// * /// * `thinking_budget_tokens` - Budget tokens for thinking (optional) #[allow(clippy::too_many_arguments)] pub fn new( @@ -486,26 +513,109 @@ impl Anthropic { reasoning: Option, thinking_budget_tokens: Option, ) -> Self { + let timeout = timeout_seconds.unwrap_or(30); let mut builder = Client::builder(); - if let Some(sec) = timeout_seconds { - builder = builder.timeout(std::time::Duration::from_secs(sec)); + if timeout > 0 { + builder = builder.timeout(std::time::Duration::from_secs(timeout)); } - Self { - api_key: api_key.into(), - model: model.unwrap_or_else(|| "claude-3-sonnet-20240229".to_string()), - max_tokens: max_tokens.unwrap_or(300), - temperature: temperature.unwrap_or(0.7), - system: system.unwrap_or_else(|| "You are a helpful assistant.".to_string()), - timeout_seconds: timeout_seconds.unwrap_or(30), + Self::with_client( + builder.build().expect("Failed to build reqwest Client"), + api_key, + model, + max_tokens, + temperature, + timeout_seconds, + system, top_p, top_k, tools, tool_choice, - reasoning: reasoning.unwrap_or(false), + reasoning, thinking_budget_tokens, - client: builder.build().expect("Failed to build reqwest Client"), + ) + } + + /// Creates a new Anthropic client with a custom HTTP client. + /// + /// This constructor allows sharing a pre-configured `reqwest::Client` across + /// multiple provider instances, enabling connection pooling and custom + /// HTTP settings. 
+ /// + /// # Arguments + /// + /// * `client` - A pre-configured `reqwest::Client` for HTTP requests + /// * `api_key` - Anthropic API key for authentication + /// * `model` - Model identifier (defaults to "claude-3-sonnet-20240229") + /// * `max_tokens` - Maximum tokens in response (defaults to 300) + /// * `temperature` - Sampling temperature (defaults to 0.7) + /// * `timeout_seconds` - Request timeout in seconds (defaults to 30) + /// * `system` - System prompt (defaults to "You are a helpful assistant.") + /// * `thinking_budget_tokens` - Budget tokens for thinking (optional) + /// + /// # Examples + /// + /// ```rust + /// use reqwest::Client; + /// use std::time::Duration; + /// + /// // Create a shared client with custom settings + /// let shared_client = Client::builder() + /// .timeout(Duration::from_secs(120)) + /// .build() + /// .unwrap(); + /// + /// // Use the shared client for multiple Anthropic instances + /// let anthropic = llm::backends::anthropic::Anthropic::with_client( + /// shared_client.clone(), + /// "your-api-key", + /// Some("claude-3-opus-20240229".to_string()), + /// Some(1000), + /// Some(0.7), + /// Some(120), + /// Some("You are a helpful assistant.".to_string()), + /// None, + /// None, + /// None, + /// None, + /// None, + /// None, + /// ); + /// ``` + #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + api_key: impl Into, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + tools: Option>, + tool_choice: Option, + reasoning: Option, + thinking_budget_tokens: Option, + ) -> Self { + Self { + config: Arc::new(AnthropicConfig { + api_key: api_key.into(), + model: model.unwrap_or_else(|| "claude-3-sonnet-20240229".to_string()), + max_tokens: max_tokens.unwrap_or(300), + temperature: temperature.unwrap_or(0.7), + system: system.unwrap_or_else(|| "You are a helpful assistant.".to_string()), + timeout_seconds: 
timeout_seconds.unwrap_or(30), + top_p, + top_k, + tools, + tool_choice, + reasoning: reasoning.unwrap_or(false), + thinking_budget_tokens, + }), + client, } } + } #[async_trait] @@ -525,18 +635,18 @@ impl ChatProvider for Anthropic { messages: &[ChatMessage], tools: Option<&[Tool]>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Anthropic API key".to_string())); } let anthropic_messages = Self::convert_messages_to_anthropic(messages); let (anthropic_tools, final_tool_choice) = - Self::prepare_tools_and_choice(tools, self.tools.as_deref(), &self.tool_choice); + Self::prepare_tools_and_choice(tools, self.config.tools.as_deref(), &self.config.tool_choice); - let thinking = if self.reasoning { + let thinking = if self.config.reasoning { Some(ThinkingConfig { thinking_type: "enabled".to_string(), - budget_tokens: self.thinking_budget_tokens.unwrap_or(16000), + budget_tokens: self.config.thinking_budget_tokens.unwrap_or(16000), }) } else { None @@ -544,13 +654,13 @@ impl ChatProvider for Anthropic { let req_body = AnthropicCompleteRequest { messages: anthropic_messages, - model: &self.model, - max_tokens: Some(self.max_tokens), - temperature: Some(self.temperature), - system: Some(&self.system), + model: &self.config.model, + max_tokens: Some(self.config.max_tokens), + temperature: Some(self.config.temperature), + system: Some(&self.config.system), stream: Some(false), - top_p: self.top_p, - top_k: self.top_k, + top_p: self.config.top_p, + top_k: self.config.top_k, tools: anthropic_tools, tool_choice: final_tool_choice, thinking, @@ -559,13 +669,13 @@ impl ChatProvider for Anthropic { let mut request = self .client .post("https://api.anthropic.com/v1/messages") - .header("x-api-key", &self.api_key) + .header("x-api-key", &self.config.api_key) .header("Content-Type", "application/json") .header("anthropic-version", "2023-06-01") .json(&req_body); - if self.timeout_seconds > 0 { - 
request = request.timeout(std::time::Duration::from_secs(self.timeout_seconds)); + if self.config.timeout_seconds > 0 { + request = request.timeout(std::time::Duration::from_secs(self.config.timeout_seconds)); } if log::log_enabled!(log::Level::Trace) { @@ -613,7 +723,7 @@ impl ChatProvider for Anthropic { messages: &[ChatMessage], ) -> Result> + Send>>, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Anthropic API key".to_string())); } @@ -671,13 +781,13 @@ impl ChatProvider for Anthropic { let req_body = AnthropicCompleteRequest { messages: anthropic_messages, - model: &self.model, - max_tokens: Some(self.max_tokens), - temperature: Some(self.temperature), - system: Some(&self.system), + model: &self.config.model, + max_tokens: Some(self.config.max_tokens), + temperature: Some(self.config.temperature), + system: Some(&self.config.system), stream: Some(true), - top_p: self.top_p, - top_k: self.top_k, + top_p: self.config.top_p, + top_k: self.config.top_k, tools: None, tool_choice: None, thinking: None, @@ -686,13 +796,13 @@ impl ChatProvider for Anthropic { let mut request = self .client .post("https://api.anthropic.com/v1/messages") - .header("x-api-key", &self.api_key) + .header("x-api-key", &self.config.api_key) .header("Content-Type", "application/json") .header("anthropic-version", "2023-06-01") .json(&req_body); - if self.timeout_seconds > 0 { - request = request.timeout(std::time::Duration::from_secs(self.timeout_seconds)); + if self.config.timeout_seconds > 0 { + request = request.timeout(std::time::Duration::from_secs(self.config.timeout_seconds)); } let response = request.send().await?; @@ -726,23 +836,23 @@ impl ChatProvider for Anthropic { tools: Option<&[Tool]>, ) -> Result> + Send>>, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Anthropic API key".to_string())); } let anthropic_messages = 
Self::convert_messages_to_anthropic(messages); let (anthropic_tools, final_tool_choice) = - Self::prepare_tools_and_choice(tools, self.tools.as_deref(), &self.tool_choice); + Self::prepare_tools_and_choice(tools, self.config.tools.as_deref(), &self.config.tool_choice); let req_body = AnthropicCompleteRequest { messages: anthropic_messages, - model: &self.model, - max_tokens: Some(self.max_tokens), - temperature: Some(self.temperature), - system: Some(&self.system), + model: &self.config.model, + max_tokens: Some(self.config.max_tokens), + temperature: Some(self.config.temperature), + system: Some(&self.config.system), stream: Some(true), - top_p: self.top_p, - top_k: self.top_k, + top_p: self.config.top_p, + top_k: self.config.top_k, tools: anthropic_tools, tool_choice: final_tool_choice, thinking: None, // Thinking not supported with streaming tools @@ -751,13 +861,13 @@ impl ChatProvider for Anthropic { let mut request = self .client .post("https://api.anthropic.com/v1/messages") - .header("x-api-key", &self.api_key) + .header("x-api-key", &self.config.api_key) .header("Content-Type", "application/json") .header("anthropic-version", "2023-06-01") .json(&req_body); - if self.timeout_seconds > 0 { - request = request.timeout(std::time::Duration::from_secs(self.timeout_seconds)); + if self.config.timeout_seconds > 0 { + request = request.timeout(std::time::Duration::from_secs(self.config.timeout_seconds)); } if log::log_enabled!(log::Level::Trace) { @@ -924,7 +1034,7 @@ impl ModelsProvider for Anthropic { let resp = self .client .get("https://api.anthropic.com/v1/models") - .header("x-api-key", &self.api_key) + .header("x-api-key", &self.config.api_key) .header("Content-Type", "application/json") .header("anthropic-version", "2023-06-01") .send() @@ -938,7 +1048,7 @@ impl ModelsProvider for Anthropic { impl crate::LLMProvider for Anthropic { fn tools(&self) -> Option<&[Tool]> { - self.tools.as_deref() + self.config.tools.as_deref() } } diff --git 
a/src/backends/azure_openai.rs b/src/backends/azure_openai.rs index 562f755..203b4eb 100644 --- a/src/backends/azure_openai.rs +++ b/src/backends/azure_openai.rs @@ -2,6 +2,8 @@ //! //! This module provides integration with Azure OpenAI's GPT models through their API. +use std::sync::Arc; + #[cfg(feature = "azure_openai")] use crate::{ builder::LLMBackend, @@ -24,29 +26,54 @@ use either::*; use reqwest::{Client, Url}; use serde::{Deserialize, Serialize}; -/// Client for interacting with Azure OpenAI's API. -/// -/// Provides methods for chat and completion requests using Azure OpenAI's models. -pub struct AzureOpenAI { +/// Configuration for the Azure OpenAI client. +#[derive(Debug)] +pub struct AzureOpenAIConfig { + /// API key for authentication. pub api_key: String, + /// API version string. pub api_version: String, + /// Base URL for API requests. pub base_url: Url, + /// Model identifier. pub model: String, + /// Maximum tokens to generate in responses. pub max_tokens: Option, + /// Sampling temperature for response randomness. pub temperature: Option, + /// System prompt to guide model behavior. pub system: Option, + /// Request timeout in seconds. pub timeout_seconds: Option, + /// Top-p (nucleus) sampling parameter. pub top_p: Option, + /// Top-k sampling parameter. pub top_k: Option, + /// Available tools for the model to use. pub tools: Option>, + /// Tool choice configuration. pub tool_choice: Option, - /// Embedding parameters + /// Encoding format for embeddings. pub embedding_encoding_format: Option, + /// Dimensions for embeddings. pub embedding_dimensions: Option, + /// Reasoning effort level. pub reasoning_effort: Option, - /// JSON schema for structured output + /// JSON schema for structured output. pub json_schema: Option, - client: Client, +} + +/// Client for interacting with Azure OpenAI's API. +/// +/// Provides methods for chat and completion requests using Azure OpenAI's models. 
+/// +/// The client uses `Arc` internally for configuration, making cloning cheap. +#[derive(Debug, Clone)] +pub struct AzureOpenAI { + /// Shared configuration wrapped in Arc for cheap cloning. + pub config: Arc, + /// HTTP client for making requests. + pub client: Client, } /// Individual message in an OpenAI chat conversation. @@ -356,29 +383,74 @@ impl AzureOpenAI { if let Some(sec) = timeout_seconds { builder = builder.timeout(std::time::Duration::from_secs(sec)); } - - let endpoint = endpoint.into(); - let deployment_id = deployment_id.into(); - - Self { - api_key: api_key.into(), - api_version: api_version.into(), - base_url: Url::parse(&format!("{endpoint}/openai/deployments/{deployment_id}/")) - .expect("Failed to parse base Url"), - model: model.unwrap_or("gpt-3.5-turbo".to_string()), + Self::with_client( + builder.build().expect("Failed to build reqwest Client"), + api_key, + api_version, + deployment_id, + endpoint, + model, max_tokens, temperature, - system, timeout_seconds, + system, top_p, top_k, - tools, - tool_choice, embedding_encoding_format, embedding_dimensions, - client: builder.build().expect("Failed to build reqwest Client"), + tools, + tool_choice, reasoning_effort, json_schema, + ) + } + + /// Creates a new Azure OpenAI client with a custom HTTP client. 
+ #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + api_key: impl Into, + api_version: impl Into, + deployment_id: impl Into, + endpoint: impl Into, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + embedding_encoding_format: Option, + embedding_dimensions: Option, + tools: Option>, + tool_choice: Option, + reasoning_effort: Option, + json_schema: Option, + ) -> Self { + let endpoint = endpoint.into(); + let deployment_id = deployment_id.into(); + + Self { + config: Arc::new(AzureOpenAIConfig { + api_key: api_key.into(), + api_version: api_version.into(), + base_url: Url::parse(&format!("{endpoint}/openai/deployments/{deployment_id}/")) + .expect("Failed to parse base Url"), + model: model.unwrap_or("gpt-3.5-turbo".to_string()), + max_tokens, + temperature, + system, + timeout_seconds, + top_p, + top_k, + tools, + tool_choice, + embedding_encoding_format, + embedding_dimensions, + reasoning_effort, + json_schema, + }), + client, } } } @@ -399,7 +471,7 @@ impl ChatProvider for AzureOpenAI { messages: &[ChatMessage], tools: Option<&[Tool]>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError( "Missing Azure OpenAI API key".to_string(), )); @@ -425,7 +497,7 @@ impl ChatProvider for AzureOpenAI { } } - if let Some(system) = &self.system { + if let Some(system) = &self.config.system { openai_msgs.insert( 0, AzureOpenAIChatMessage { @@ -445,26 +517,26 @@ impl ChatProvider for AzureOpenAI { // Build the response format object let response_format: Option = - self.json_schema.clone().map(|s| s.into()); + self.config.json_schema.clone().map(|s| s.into()); - let request_tools = tools.map(|t| t.to_vec()).or_else(|| self.tools.clone()); + let request_tools = tools.map(|t| t.to_vec()).or_else(|| self.config.tools.clone()); let request_tool_choice = if request_tools.is_some() { - 
self.tool_choice.clone() + self.config.tool_choice.clone() } else { None }; let body = AzureOpenAIChatRequest { - model: &self.model, + model: &self.config.model, messages: openai_msgs, - max_tokens: self.max_tokens, - temperature: self.temperature, + max_tokens: self.config.max_tokens, + temperature: self.config.temperature, stream: false, - top_p: self.top_p, - top_k: self.top_k, + top_p: self.config.top_p, + top_k: self.config.top_k, tools: request_tools, tool_choice: request_tool_choice, - reasoning_effort: self.reasoning_effort.clone(), + reasoning_effort: self.config.reasoning_effort.clone(), response_format, }; @@ -475,20 +547,21 @@ impl ChatProvider for AzureOpenAI { } let mut url = self + .config .base_url .join("chat/completions") .map_err(|e| LLMError::HttpError(e.to_string()))?; url.query_pairs_mut() - .append_pair("api-version", &self.api_version); + .append_pair("api-version", &self.config.api_version); let mut request = self .client .post(url) - .header("api-key", &self.api_key) + .header("api-key", &self.config.api_key) .json(&body); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } @@ -542,34 +615,36 @@ impl CompletionProvider for AzureOpenAI { #[async_trait] impl EmbeddingProvider for AzureOpenAI { async fn embed(&self, input: Vec) -> Result>, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing OpenAI API key".into())); } let emb_format = self + .config .embedding_encoding_format .clone() .unwrap_or_else(|| "float".to_string()); let body = OpenAIEmbeddingRequest { - model: self.model.clone(), + model: self.config.model.clone(), input, encoding_format: Some(emb_format), - dimensions: self.embedding_dimensions, + dimensions: self.config.embedding_dimensions, }; let mut url = self + .config .base_url .join("embeddings") .map_err(|e| 
LLMError::HttpError(e.to_string()))?; url.query_pairs_mut() - .append_pair("api-version", &self.api_version); + .append_pair("api-version", &self.config.api_version); let resp = self .client .post(url) - .header("api-key", &self.api_key) + .header("api-key", &self.config.api_key) .json(&body) .send() .await? @@ -584,7 +659,7 @@ impl EmbeddingProvider for AzureOpenAI { impl LLMProvider for AzureOpenAI { fn tools(&self) -> Option<&[Tool]> { - self.tools.as_deref() + self.config.tools.as_deref() } } @@ -612,23 +687,24 @@ impl ModelsProvider for AzureOpenAI { &self, _request: Option<&ModelListRequest>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError( "Missing Azure OpenAI API key".to_string(), )); } let mut url = self + .config .base_url .join("models") .map_err(|e| LLMError::HttpError(e.to_string()))?; url.query_pairs_mut() - .append_pair("api-version", &self.api_version); + .append_pair("api-version", &self.config.api_version); - let mut request = self.client.get(url).header("api-key", &self.api_key); + let mut request = self.client.get(url).header("api-key", &self.config.api_key); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } diff --git a/src/backends/cohere.rs b/src/backends/cohere.rs index a835556..4a660e4 100644 --- a/src/backends/cohere.rs +++ b/src/backends/cohere.rs @@ -104,7 +104,7 @@ struct CohereEmbeddingResponse { impl LLMProvider for Cohere { fn tools(&self) -> Option<&[Tool]> { - self.tools.as_deref() + self.config.tools.as_deref() } } @@ -136,26 +136,26 @@ impl SpeechToTextProvider for Cohere { #[async_trait] impl EmbeddingProvider for Cohere { async fn embed(&self, input: Vec) -> Result>, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Cohere API key".into())); } let body = 
CohereEmbeddingRequest { - model: self.model.clone(), + model: self.config.model.to_owned(), input, - encoding_format: self.embedding_encoding_format.clone(), - dimensions: self.embedding_dimensions, + encoding_format: self.config.embedding_encoding_format.as_deref().map(|s| s.to_owned()), + dimensions: self.config.embedding_dimensions, }; let url = self - .base_url + .config.base_url .join("embeddings") .map_err(|e| LLMError::HttpError(e.to_string()))?; let resp = self .client .post(url) - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .json(&body) .send() .await? diff --git a/src/backends/deepseek.rs b/src/backends/deepseek.rs index 7ef83bf..b9e9a66 100644 --- a/src/backends/deepseek.rs +++ b/src/backends/deepseek.rs @@ -2,6 +2,8 @@ //! //! This module provides integration with DeepSeek's models through their API. +use std::sync::Arc; + use crate::chat::{ChatResponse, Tool}; #[cfg(feature = "deepseek")] use crate::{ @@ -21,14 +23,32 @@ use serde::{Deserialize, Serialize}; use crate::ToolCall; -pub struct DeepSeek { +/// Configuration for the DeepSeek client. +#[derive(Debug)] +pub struct DeepSeekConfig { + /// API key for authentication with DeepSeek. pub api_key: String, + /// Model identifier. pub model: String, + /// Maximum tokens to generate in responses. pub max_tokens: Option, + /// Sampling temperature for response randomness. pub temperature: Option, + /// System prompt to guide model behavior. pub system: Option, + /// Request timeout in seconds. pub timeout_seconds: Option, - client: Client, +} + +/// Client for interacting with DeepSeek's API. +/// +/// The client uses `Arc` internally for configuration, making cloning cheap. +#[derive(Debug, Clone)] +pub struct DeepSeek { + /// Shared configuration wrapped in Arc for cheap cloning. + pub config: Arc, + /// HTTP client for making requests. 
+ pub client: Client, } #[derive(Serialize)] @@ -95,14 +115,37 @@ impl DeepSeek { if let Some(sec) = timeout_seconds { builder = builder.timeout(std::time::Duration::from_secs(sec)); } - Self { - api_key: api_key.into(), - model: model.unwrap_or("deepseek-chat".to_string()), + Self::with_client( + builder.build().expect("Failed to build reqwest Client"), + api_key, + model, max_tokens, temperature, - system, timeout_seconds, - client: builder.build().expect("Failed to build reqwest Client"), + system, + ) + } + + /// Creates a new DeepSeek client with a custom HTTP client. + pub fn with_client( + client: Client, + api_key: impl Into, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + ) -> Self { + Self { + config: Arc::new(DeepSeekConfig { + api_key: api_key.into(), + model: model.unwrap_or("deepseek-chat".to_string()), + max_tokens, + temperature, + system, + timeout_seconds, + }), + client, } } } @@ -119,7 +162,7 @@ impl ChatProvider for DeepSeek { /// /// The provider's response text or an error async fn chat(&self, messages: &[ChatMessage]) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing DeepSeek API key".to_string())); } @@ -134,7 +177,7 @@ impl ChatProvider for DeepSeek { }) .collect(); - if let Some(system) = &self.system { + if let Some(system) = &self.config.system { deepseek_msgs.insert( 0, DeepSeekChatMessage { @@ -145,9 +188,9 @@ impl ChatProvider for DeepSeek { } let body = DeepSeekChatRequest { - model: &self.model, + model: &self.config.model, messages: deepseek_msgs, - temperature: self.temperature, + temperature: self.config.temperature, stream: false, }; @@ -160,10 +203,10 @@ impl ChatProvider for DeepSeek { let mut request = self .client .post("https://api.deepseek.com/v1/chat/completions") - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .json(&body); - if let Some(timeout) = 
self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } @@ -230,14 +273,14 @@ impl ModelsProvider for DeepSeek { &self, _request: Option<&ModelListRequest>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing DeepSeek API key".to_string())); } let resp = self .client .get("https://api.deepseek.com/v1/models") - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .send() .await? .error_for_status()?; diff --git a/src/backends/elevenlabs.rs b/src/backends/elevenlabs.rs index 07f3bd2..1d39efa 100644 --- a/src/backends/elevenlabs.rs +++ b/src/backends/elevenlabs.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::chat::{ChatMessage, ChatProvider, ChatResponse, Tool}; use crate::completion::{CompletionProvider, CompletionRequest, CompletionResponse}; use crate::embedding::EmbeddingProvider; @@ -12,23 +14,34 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; use std::time::Duration; +/// Configuration for the ElevenLabs client. +#[derive(Debug)] +/// Configuration for the ElevenLabs client. +pub struct ElevenLabsConfig { + /// API key for authentication. + pub api_key: String, + /// Model identifier. + pub model_id: String, + /// Base URL for API requests. + pub base_url: String, + /// Request timeout in seconds. + pub timeout_seconds: Option, + /// Voice setting for TTS. + pub voice: Option, +} + /// ElevenLabs speech to text backend implementation /// /// This struct provides functionality for speech-to-text transcription using the ElevenLabs API. /// It implements various LLM provider traits but only supports speech-to-text functionality. +/// +/// The client uses `Arc` internally for configuration, making cloning cheap. 
+#[derive(Debug, Clone)] pub struct ElevenLabs { - /// API key for ElevenLabs authentication - api_key: String, - /// Model identifier for speech-to-text - model_id: String, - /// Base URL for API requests - base_url: String, - /// Optional timeout duration in seconds - timeout_seconds: Option, - /// HTTP client for making requests - client: Client, - /// Voice ID to use for speech synthesis - voice: Option, + /// Shared configuration wrapped in Arc for cheap cloning. + pub config: Arc, + /// HTTP client for making requests. + pub client: Client, } /// Internal representation of a word from ElevenLabs API response @@ -91,14 +104,28 @@ impl ElevenLabs { base_url: String, timeout_seconds: Option, voice: Option, + ) -> Self { + Self::with_client(Client::new(), api_key, model_id, base_url, timeout_seconds, voice) + } + + /// Creates a new ElevenLabs instance with a custom HTTP client. + pub fn with_client( + client: Client, + api_key: String, + model_id: String, + base_url: String, + timeout_seconds: Option, + voice: Option, ) -> Self { Self { - api_key, - model_id, - base_url, - timeout_seconds, - client: Client::new(), - voice, + config: Arc::new(ElevenLabsConfig { + api_key, + model_id, + base_url, + timeout_seconds, + voice, + }), + client, } } } @@ -116,19 +143,19 @@ impl SpeechToTextProvider for ElevenLabs { /// * `Ok(String)` - Transcribed text /// * `Err(LLMError)` - Error if transcription fails async fn transcribe(&self, audio: Vec) -> Result { - let url = format!("{}/speech-to-text", self.base_url); + let url = format!("{}/speech-to-text", self.config.base_url); let part = reqwest::multipart::Part::bytes(audio).file_name("audio.wav"); let form = reqwest::multipart::Form::new() - .text("model_id", self.model_id.clone()) + .text("model_id", self.config.model_id.clone()) .part("file", part); let mut req = self .client .post(url) - .header("xi-api-key", &self.api_key) + .header("xi-api-key", &self.config.api_key) .multipart(form); - if let Some(t) = 
self.timeout_seconds { + if let Some(t) = self.config.timeout_seconds { req = req.timeout(Duration::from_secs(t)); } @@ -169,9 +196,9 @@ impl SpeechToTextProvider for ElevenLabs { /// * `Ok(String)` - Transcribed text /// * `Err(LLMError)` - Error if transcription fails async fn transcribe_file(&self, file_path: &str) -> Result { - let url = format!("{}/speech-to-text", self.base_url); + let url = format!("{}/speech-to-text", self.config.base_url); let form = reqwest::multipart::Form::new() - .text("model_id", self.model_id.clone()) + .text("model_id", self.config.model_id.clone()) .file("file", file_path) .await .map_err(|e| LLMError::HttpError(e.to_string()))?; @@ -179,10 +206,10 @@ impl SpeechToTextProvider for ElevenLabs { let mut req = self .client .post(url) - .header("xi-api-key", &self.api_key) + .header("xi-api-key", &self.config.api_key) .multipart(form); - if let Some(t) = self.timeout_seconds { + if let Some(t) = self.config.timeout_seconds { req = req.timeout(Duration::from_secs(t)); } @@ -278,25 +305,25 @@ impl TextToSpeechProvider for ElevenLabs { async fn speech(&self, text: &str) -> Result, LLMError> { let url = format!( "{}/text-to-speech/{}?output_format=mp3_44100_128", - self.base_url, - self.voice + self.config.base_url, + self.config.voice .clone() .unwrap_or("JBFqnCBsd6RMkjVDRZzb".to_string()) ); let body = serde_json::json!({ "text": text, - "model_id": self.model_id + "model_id": self.config.model_id }); let mut req = self .client .post(url) - .header("xi-api-key", &self.api_key) + .header("xi-api-key", &self.config.api_key) .header("Content-Type", "application/json") .json(&body); - if let Some(t) = self.timeout_seconds { + if let Some(t) = self.config.timeout_seconds { req = req.timeout(Duration::from_secs(t)); } diff --git a/src/backends/google.rs b/src/backends/google.rs index 32fbb9d..fb07e21 100644 --- a/src/backends/google.rs +++ b/src/backends/google.rs @@ -40,6 +40,8 @@ //! } //! 
``` +use std::sync::Arc; + use crate::{ builder::LLMBackend, chat::{ @@ -62,33 +64,44 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::Value; -/// Client for interacting with Google's Gemini API. -/// -/// This struct holds the configuration and state needed to make requests to the Gemini API. -/// It implements the [`ChatProvider`], [`CompletionProvider`], and [`EmbeddingProvider`] traits. -pub struct Google { - /// API key for authentication with Google's API +/// Configuration for the Google Gemini client. +#[derive(Debug)] +pub struct GoogleConfig { + /// API key for authentication with Google. pub api_key: String, - /// Model identifier (e.g. "gemini-1.5-flash") + /// Model identifier (e.g., "gemini-1.5-flash"). pub model: String, - /// Maximum number of tokens to generate in responses + /// Maximum tokens to generate in responses. pub max_tokens: Option<u32>, - /// Sampling temperature between 0.0 and 1.0 + /// Sampling temperature for response randomness. pub temperature: Option<f32>, - /// Optional system prompt to set context + /// System prompt to guide model behavior. pub system: Option<String>, - /// Request timeout in seconds + /// Request timeout in seconds. pub timeout_seconds: Option<u64>, - /// Top-p sampling parameter + /// Top-p (nucleus) sampling parameter. pub top_p: Option<f32>, - /// Top-k sampling parameter + /// Top-k sampling parameter. pub top_k: Option<u32>, - /// JSON schema for structured output + /// JSON schema for structured output. pub json_schema: Option<StructuredOutputFormat>, - /// Available tools for function calling + /// Available tools for the model to use. pub tools: Option<Vec<Tool>>, - /// HTTP client for making API requests - client: Client, +} + +/// Client for interacting with Google's Gemini API. +/// +/// This struct holds the configuration and state needed to make requests to the Gemini API. +/// It implements the [`ChatProvider`], [`CompletionProvider`], and [`EmbeddingProvider`] traits. 
+/// +/// The client uses `Arc` internally for configuration, making cloning cheap +/// (only an atomic reference count increment). +#[derive(Debug, Clone)] +pub struct Google { + /// Shared configuration wrapped in Arc for cheap cloning. + pub config: Arc, + /// HTTP client for making requests. + pub client: Client, } /// Request body for chat completions @@ -474,7 +487,6 @@ impl Google { /// * `temperature` - Sampling temperature between 0.0 and 1.0 /// * `timeout_seconds` - Request timeout in seconds /// * `system` - System prompt to set context - /// * `stream` - Whether to stream responses /// * `top_p` - Top-p sampling parameter /// * `top_k` - Top-k sampling parameter /// * `json_schema` - JSON schema for structured output @@ -500,18 +512,50 @@ impl Google { if let Some(sec) = timeout_seconds { builder = builder.timeout(std::time::Duration::from_secs(sec)); } - Self { - api_key: api_key.into(), - model: model.unwrap_or_else(|| "gemini-1.5-flash".to_string()), + Self::with_client( + builder.build().expect("Failed to build reqwest Client"), + api_key, + model, max_tokens, temperature, - system, timeout_seconds, + system, top_p, top_k, json_schema, tools, - client: builder.build().expect("Failed to build reqwest Client"), + ) + } + + /// Creates a new Google Gemini client with a custom HTTP client. 
+ #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + api_key: impl Into, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + json_schema: Option, + tools: Option>, + ) -> Self { + Self { + config: Arc::new(GoogleConfig { + api_key: api_key.into(), + model: model.unwrap_or_else(|| "gemini-1.5-flash".to_string()), + max_tokens, + temperature, + system, + timeout_seconds, + top_p, + top_k, + json_schema, + tools, + }), + client, } } } @@ -528,14 +572,14 @@ impl ChatProvider for Google { /// /// The model's response text or an error async fn chat(&self, messages: &[ChatMessage]) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Google API key".to_string())); } let mut chat_contents = Vec::with_capacity(messages.len()); // Add system message if present - if let Some(system) = &self.system { + if let Some(system) = &self.config.system { chat_contents.push(GoogleChatContent { role: "user", parts: vec![GoogleContentPart::Text(system)], @@ -601,17 +645,17 @@ impl ChatProvider for Google { } // Remove generation_config if empty to avoid validation errors - let generation_config = if self.max_tokens.is_none() - && self.temperature.is_none() - && self.top_p.is_none() - && self.top_k.is_none() - && self.json_schema.is_none() + let generation_config = if self.config.max_tokens.is_none() + && self.config.temperature.is_none() + && self.config.top_p.is_none() + && self.config.top_k.is_none() + && self.config.json_schema.is_none() { None } else { // If json_schema and json_schema.schema are not None, use json_schema.schema as the response schema and set response_mime_type to JSON // Google's API doesn't need the schema to have a "name" field, so we can just use the schema directly. 
- let (response_mime_type, response_schema) = if let Some(json_schema) = &self.json_schema + let (response_mime_type, response_schema) = if let Some(json_schema) = &self.config.json_schema { if let Some(schema) = &json_schema.schema { // If the schema has an "additionalProperties" field (as required by OpenAI), remove it as Google's API doesn't support it @@ -627,10 +671,10 @@ impl ChatProvider for Google { (None, None) }; Some(GoogleGenerationConfig { - max_output_tokens: self.max_tokens, - temperature: self.temperature, - top_p: self.top_p, - top_k: self.top_k, + max_output_tokens: self.config.max_tokens, + temperature: self.config.temperature, + top_p: self.config.top_p, + top_k: self.config.top_k, response_mime_type, response_schema, }) @@ -649,12 +693,12 @@ impl ChatProvider for Google { let url = format!( "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={key}", - model = self.model, - key = self.api_key + model = self.config.model, + key = self.config.api_key ); let mut request = self.client.post(&url).json(&req_body); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } @@ -696,14 +740,14 @@ impl ChatProvider for Google { messages: &[ChatMessage], tools: Option<&[Tool]>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Google API key".to_string())); } let mut chat_contents = Vec::with_capacity(messages.len()); // Add system message if present - if let Some(system) = &self.system { + if let Some(system) = &self.config.system { chat_contents.push(GoogleChatContent { role: "user", parts: vec![GoogleContentPart::Text(system)], @@ -779,7 +823,7 @@ impl ChatProvider for Google { let generation_config = { // If json_schema and json_schema.schema are not None, use json_schema.schema as the response schema and set 
response_mime_type to JSON // Google's API doesn't need the schema to have a "name" field, so we can just use the schema directly. - let (response_mime_type, response_schema) = if let Some(json_schema) = &self.json_schema + let (response_mime_type, response_schema) = if let Some(json_schema) = &self.config.json_schema { if let Some(schema) = &json_schema.schema { // If the schema has an "additionalProperties" field (as required by OpenAI), remove it as Google's API doesn't support it @@ -798,10 +842,10 @@ impl ChatProvider for Google { }; Some(GoogleGenerationConfig { - max_output_tokens: self.max_tokens, - temperature: self.temperature, - top_p: self.top_p, - top_k: self.top_k, + max_output_tokens: self.config.max_tokens, + temperature: self.config.temperature, + top_p: self.config.top_p, + top_k: self.config.top_k, response_mime_type, response_schema, }) @@ -821,14 +865,14 @@ impl ChatProvider for Google { let url = format!( "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={key}", - model = self.model, - key = self.api_key + model = self.config.model, + key = self.config.api_key ); let mut request = self.client.post(&url).json(&req_body); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } @@ -906,11 +950,11 @@ impl ChatProvider for Google { std::pin::Pin> + Send>>, LLMError, > { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Google API key".to_string())); } let mut chat_contents = Vec::with_capacity(messages.len()); - if let Some(system) = &self.system { + if let Some(system) = &self.config.system { chat_contents.push(GoogleChatContent { role: "user", parts: vec![GoogleContentPart::Text(system)], @@ -941,18 +985,18 @@ impl ChatProvider for Google { }, }); } - let generation_config = if self.max_tokens.is_none() - && self.temperature.is_none() - && 
self.top_p.is_none() - && self.top_k.is_none() + let generation_config = if self.config.max_tokens.is_none() + && self.config.temperature.is_none() + && self.config.top_p.is_none() + && self.config.top_k.is_none() { None } else { Some(GoogleGenerationConfig { - max_output_tokens: self.max_tokens, - temperature: self.temperature, - top_p: self.top_p, - top_k: self.top_k, + max_output_tokens: self.config.max_tokens, + temperature: self.config.temperature, + top_p: self.config.top_p, + top_k: self.config.top_k, response_mime_type: None, response_schema: None, }) @@ -965,12 +1009,12 @@ impl ChatProvider for Google { }; let url = format!( "https://generativelanguage.googleapis.com/v1beta/models/{model}:streamGenerateContent?alt=sse&key={key}", - model = self.model, - key = self.api_key + model = self.config.model, + key = self.config.api_key ); let mut request = self.client.post(&url).json(&req_body); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } let response = request.send().await?; @@ -1012,7 +1056,7 @@ impl CompletionProvider for Google { #[async_trait] impl EmbeddingProvider for Google { async fn embed(&self, texts: Vec) -> Result>, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Google API key".to_string())); } @@ -1029,7 +1073,7 @@ impl EmbeddingProvider for Google { let url = format!( "https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent?key={}", - self.api_key + self.config.api_key ); let resp = self @@ -1058,7 +1102,7 @@ impl SpeechToTextProvider for Google { impl LLMProvider for Google { fn tools(&self) -> Option<&[Tool]> { - self.tools.as_deref() + self.config.tools.as_deref() } } @@ -1222,13 +1266,13 @@ impl ModelsProvider for Google { &self, _request: Option<&ModelListRequest>, ) -> Result, LLMError> { - if 
self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Google API key".to_string())); } let url = format!( "https://generativelanguage.googleapis.com/v1beta/models?key={}", - self.api_key + self.config.api_key ); let resp = self.client.get(&url).send().await?.error_for_status()?; diff --git a/src/backends/groq.rs b/src/backends/groq.rs index 6a7a97f..8fd04fc 100644 --- a/src/backends/groq.rs +++ b/src/backends/groq.rs @@ -93,7 +93,7 @@ impl Groq { impl LLMProvider for Groq { fn tools(&self) -> Option<&[Tool]> { - self.tools.as_deref() + self.config.tools.as_deref() } } @@ -133,7 +133,7 @@ impl ModelsProvider for Groq { &self, _request: Option<&ModelListRequest>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Groq API key".to_string())); } @@ -142,7 +142,7 @@ impl ModelsProvider for Groq { let resp = self .client .get(&url) - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .send() .await? .error_for_status()?; diff --git a/src/backends/huggingface.rs b/src/backends/huggingface.rs index 233ab85..0faa049 100644 --- a/src/backends/huggingface.rs +++ b/src/backends/huggingface.rs @@ -79,7 +79,7 @@ impl HuggingFace { impl LLMProvider for HuggingFace { fn tools(&self) -> Option<&[Tool]> { - self.tools.as_deref() + self.config.tools.as_deref() } } @@ -119,7 +119,7 @@ impl ModelsProvider for HuggingFace { &self, _request: Option<&ModelListRequest>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError( "Missing HuggingFace API key".to_string(), )); @@ -130,7 +130,7 @@ impl ModelsProvider for HuggingFace { let resp = self .client .get(&url) - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .send() .await? 
.error_for_status()?; diff --git a/src/backends/mistral.rs b/src/backends/mistral.rs index 761d915..60615b3 100644 --- a/src/backends/mistral.rs +++ b/src/backends/mistral.rs @@ -104,7 +104,7 @@ struct MistralEmbeddingResponse { impl LLMProvider for Mistral { fn tools(&self) -> Option<&[Tool]> { - self.tools.as_deref() + self.config.tools.as_deref() } } @@ -136,26 +136,26 @@ impl SpeechToTextProvider for Mistral { #[async_trait] impl EmbeddingProvider for Mistral { async fn embed(&self, input: Vec) -> Result>, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Mistral API key".into())); } let body = MistralEmbeddingRequest { - model: self.model.clone(), + model: self.config.model.to_owned(), input, - encoding_format: self.embedding_encoding_format.clone(), - dimensions: self.embedding_dimensions, + encoding_format: self.config.embedding_encoding_format.as_deref().map(|s| s.to_owned()), + dimensions: self.config.embedding_dimensions, }; let url = self - .base_url + .config.base_url .join("embeddings") .map_err(|e| LLMError::HttpError(e.to_string()))?; let resp = self .client .post(url) - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .json(&body) .send() .await? @@ -173,14 +173,14 @@ impl ModelsProvider for Mistral { &self, _request: Option<&ModelListRequest>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Mistral API key".to_string())); } let url = format!("{}models", MistralConfig::DEFAULT_BASE_URL); let resp = self .client .get(&url) - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .send() .await? .error_for_status()?; diff --git a/src/backends/ollama.rs b/src/backends/ollama.rs index cd71fec..d83345e 100644 --- a/src/backends/ollama.rs +++ b/src/backends/ollama.rs @@ -3,6 +3,7 @@ //! This module provides integration with Ollama's local LLM server through its API. 
use std::pin::Pin; +use std::sync::Arc; use crate::{ builder::LLMBackend, @@ -26,24 +27,44 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::Value; -/// Client for interacting with Ollama's API. -/// -/// Provides methods for chat and completion requests using Ollama's models. -pub struct Ollama { +/// Configuration for the Ollama client. +#[derive(Debug)] +pub struct OllamaConfig { + /// Base URL for the Ollama API. pub base_url: String, + /// Optional API key for authentication. pub api_key: Option, + /// Model identifier. pub model: String, + /// Maximum tokens to generate in responses. pub max_tokens: Option, + /// Sampling temperature for response randomness. pub temperature: Option, + /// System prompt to guide model behavior. pub system: Option, + /// Request timeout in seconds. pub timeout_seconds: Option, + /// Top-p (nucleus) sampling parameter. pub top_p: Option, + /// Top-k sampling parameter. pub top_k: Option, - /// JSON schema for structured output + /// JSON schema for structured output. pub json_schema: Option, - /// Available tools for function calling + /// Available tools for the model to use. pub tools: Option>, - client: Client, +} + +/// Client for interacting with Ollama's API. +/// +/// Provides methods for chat and completion requests using Ollama's models. +/// +/// The client uses `Arc` internally for configuration, making cloning cheap. +#[derive(Debug, Clone)] +pub struct Ollama { + /// Shared configuration wrapped in Arc for cheap cloning. + pub config: Arc, + /// HTTP client for making requests. + pub client: Client, } /// Request payload for Ollama's chat API endpoint. 
@@ -299,7 +320,6 @@ impl Ollama { /// * `temperature` - Sampling temperature /// * `timeout_seconds` - Request timeout in seconds /// * `system` - System prompt - /// * `stream` - Whether to stream responses /// * `json_schema` - JSON schema for structured output /// * `tools` - Function tools that the model can use #[allow(clippy::too_many_arguments)] @@ -321,19 +341,53 @@ impl Ollama { if let Some(sec) = timeout_seconds { builder = builder.timeout(std::time::Duration::from_secs(sec)); } - Self { - base_url: base_url.into(), + Self::with_client( + builder.build().expect("Failed to build reqwest Client"), + base_url, api_key, - model: model.unwrap_or("llama3.1".to_string()), - temperature, + model, max_tokens, + temperature, timeout_seconds, system, top_p, top_k, json_schema, tools, - client: builder.build().expect("Failed to build reqwest Client"), + ) + } + + /// Creates a new Ollama client with a custom HTTP client. + #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + base_url: impl Into, + api_key: Option, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + json_schema: Option, + tools: Option>, + ) -> Self { + Self { + config: Arc::new(OllamaConfig { + base_url: base_url.into(), + api_key, + model: model.unwrap_or("llama3.1".to_string()), + temperature, + max_tokens, + timeout_seconds, + system, + top_p, + top_k, + json_schema, + tools, + }), + client, } } @@ -346,7 +400,7 @@ impl Ollama { let mut chat_messages: Vec = messages.iter().map(OllamaChatMessage::from).collect(); - if let Some(system) = &self.system { + if let Some(system) = &self.config.system { chat_messages.insert( 0, OllamaChatMessage { @@ -361,7 +415,7 @@ impl Ollama { let ollama_tools = tools.map(|t| t.iter().map(OllamaTool::from).collect()); // Ollama doesn't require the "name" field in the schema, so we just use the schema itself - let format = if let Some(schema) = 
&self.json_schema { + let format = if let Some(schema) = &self.config.json_schema { schema.schema.as_ref().map(|schema| OllamaResponseFormat { format: OllamaResponseType::StructuredOutput(schema.clone()), }) @@ -370,12 +424,12 @@ impl Ollama { }; OllamaChatRequest { - model: self.model.clone(), + model: self.config.model.clone(), messages: chat_messages, stream, options: Some(OllamaOptions { - top_p: self.top_p, - top_k: self.top_k, + top_p: self.config.top_p, + top_k: self.config.top_k, }), format, tools: ollama_tools, @@ -390,7 +444,7 @@ impl ChatProvider for Ollama { messages: &[ChatMessage], tools: Option<&[Tool]>, ) -> Result, LLMError> { - if self.base_url.is_empty() { + if self.config.base_url.is_empty() { return Err(LLMError::InvalidRequest("Missing base_url".to_string())); } @@ -402,11 +456,11 @@ impl ChatProvider for Ollama { } } - let url = format!("{}/api/chat", self.base_url); + let url = format!("{}/api/chat", self.config.base_url); let mut request = self.client.post(&url).json(&req_body); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } @@ -426,10 +480,10 @@ impl ChatProvider for Ollama { ) -> Result> + Send>>, LLMError> { let req_body = self.make_chat_request(messages, None, true); - let url = format!("{}/api/chat", self.base_url); + let url = format!("{}/api/chat", self.config.base_url); let mut request = self.client.post(&url).json(&req_body); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } @@ -454,13 +508,13 @@ impl CompletionProvider for Ollama { /// /// The completion response containing the generated text or an error async fn complete(&self, req: &CompletionRequest) -> Result { - if self.base_url.is_empty() { + if self.config.base_url.is_empty() { return Err(LLMError::InvalidRequest("Missing 
base_url".to_string())); } - let url = format!("{}/api/generate", self.base_url); + let url = format!("{}/api/generate", self.config.base_url); let req_body = OllamaGenerateRequest { - model: self.model.clone(), + model: self.config.model.clone(), prompt: &req.prompt, raw: true, stream: false, @@ -488,13 +542,13 @@ impl CompletionProvider for Ollama { #[async_trait] impl EmbeddingProvider for Ollama { async fn embed(&self, text: Vec) -> Result>, LLMError> { - if self.base_url.is_empty() { + if self.config.base_url.is_empty() { return Err(LLMError::InvalidRequest("Missing base_url".to_string())); } - let url = format!("{}/api/embed", self.base_url); + let url = format!("{}/api/embed", self.config.base_url); let body = OllamaEmbeddingRequest { - model: self.model.clone(), + model: self.config.model.clone(), input: text, }; @@ -582,15 +636,15 @@ impl ModelsProvider for Ollama { &self, _request: Option<&ModelListRequest>, ) -> Result, LLMError> { - if self.base_url.is_empty() { + if self.config.base_url.is_empty() { return Err(LLMError::InvalidRequest("Missing base_url".to_string())); } - let url = format!("{}/api/tags", self.base_url); + let url = format!("{}/api/tags", self.config.base_url); let mut request = self.client.get(&url); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } @@ -602,7 +656,7 @@ impl ModelsProvider for Ollama { impl crate::LLMProvider for Ollama { fn tools(&self) -> Option<&[Tool]> { - self.tools.as_deref() + self.config.tools.as_deref() } } diff --git a/src/backends/openai.rs b/src/backends/openai.rs index 026a9af..420ec41 100644 --- a/src/backends/openai.rs +++ b/src/backends/openai.rs @@ -287,11 +287,11 @@ impl ChatProvider for OpenAI { // Use the common prepare_messages method from the OpenAI-compatible provider let openai_msgs = self.provider.prepare_messages(messages); let response_format: Option = - 
self.provider.json_schema.clone().map(|s| s.into()); + self.provider.config.json_schema.as_ref().cloned().map(|s| s.into()); // Convert regular tools to OpenAI format let tool_calls = tools .map(|t| t.to_vec()) - .or_else(|| self.provider.tools.clone()); + .or_else(|| self.provider.config.tools.as_deref().map(|t| t.to_vec())); let mut openai_tools: Vec = Vec::new(); // Add regular function tools if let Some(tools) = &tool_calls { @@ -308,44 +308,44 @@ impl ChatProvider for OpenAI { Some(openai_tools) }; let request_tool_choice = if final_tools.is_some() { - self.provider.tool_choice.clone() + self.provider.config.tool_choice.as_ref().cloned() } else { None }; let body = OpenAIAPIChatRequest { - model: self.provider.model.as_str(), + model: &self.provider.config.model, messages: openai_msgs, input: None, - max_completion_tokens: self.provider.max_tokens, + max_completion_tokens: self.provider.config.max_tokens, max_output_tokens: None, - temperature: self.provider.temperature, + temperature: self.provider.config.temperature, stream: false, - top_p: self.provider.top_p, - top_k: self.provider.top_k, + top_p: self.provider.config.top_p, + top_k: self.provider.config.top_k, tools: final_tools, tool_choice: request_tool_choice, - reasoning_effort: self.provider.reasoning_effort.clone(), + reasoning_effort: self.provider.config.reasoning_effort.as_deref().map(|s| s.to_owned()), response_format, stream_options: None, - extra_body: self.provider.extra_body.clone(), + extra_body: self.provider.config.extra_body.clone(), }; let url = self .provider - .base_url + .config.base_url .join("chat/completions") .map_err(|e| LLMError::HttpError(e.to_string()))?; let mut request = self .provider .client .post(url) - .bearer_auth(&self.provider.api_key) + .bearer_auth(&self.provider.config.api_key) .json(&body); if log::log_enabled!(log::Level::Trace) { if let Ok(json) = serde_json::to_string(&body) { log::trace!("OpenAI request payload: {}", json); } } - if let Some(timeout) = 
self.provider.timeout_seconds { + if let Some(timeout) = self.provider.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } let response = request.send().await?; @@ -436,7 +436,7 @@ impl ChatProvider for OpenAI { > { let openai_msgs = self.provider.prepare_messages(messages); // Convert regular tools to OpenAI format for streaming - let openai_tools: Option> = self.provider.tools.as_ref().map(|tools| { + let openai_tools: Option> = self.provider.config.tools.as_deref().map(|tools| { tools .iter() .map(|tool| OpenAITool::Function { @@ -446,36 +446,36 @@ impl ChatProvider for OpenAI { .collect() }); let body = OpenAIAPIChatRequest { - model: &self.provider.model, + model: &self.provider.config.model, messages: openai_msgs, input: None, - max_completion_tokens: self.provider.max_tokens, + max_completion_tokens: self.provider.config.max_tokens, max_output_tokens: None, - temperature: self.provider.temperature, + temperature: self.provider.config.temperature, stream: true, - top_p: self.provider.top_p, - top_k: self.provider.top_k, + top_p: self.provider.config.top_p, + top_k: self.provider.config.top_k, tools: openai_tools, - tool_choice: self.provider.tool_choice.clone(), - reasoning_effort: self.provider.reasoning_effort.clone(), + tool_choice: self.provider.config.tool_choice.as_ref().cloned(), + reasoning_effort: self.provider.config.reasoning_effort.as_deref().map(|s| s.to_owned()), response_format: None, stream_options: Some(OpenAIStreamOptions { include_usage: true, }), - extra_body: self.provider.extra_body.clone(), + extra_body: self.provider.config.extra_body.clone(), }; let url = self .provider - .base_url + .config.base_url .join("chat/completions") .map_err(|e| LLMError::HttpError(e.to_string()))?; let mut request = self .provider .client .post(url) - .bearer_auth(&self.provider.api_key) + .bearer_auth(&self.provider.config.api_key) .json(&body); - if let Some(timeout) = self.provider.timeout_seconds { + if let 
Some(timeout) = self.provider.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } let response = request.send().await?; @@ -489,7 +489,7 @@ impl ChatProvider for OpenAI { } Ok(create_sse_stream( response, - self.provider.normalize_response, + self.provider.config.normalize_response, )) } @@ -523,24 +523,24 @@ impl SpeechToTextProvider for OpenAI { async fn transcribe_file(&self, file_path: &str) -> Result { let url = self - .base_url() + .provider.config.base_url .join("audio/transcriptions") .map_err(|e| LLMError::HttpError(e.to_string()))?; let form = reqwest::multipart::Form::new() - .text("model", self.model().to_string()) + .text("model", self.provider.config.model.to_string()) .text("response_format", "text") .file("file", file_path) .await .map_err(|e| LLMError::HttpError(e.to_string()))?; let mut req = self - .client() + .provider.client .post(url) - .bearer_auth(self.api_key()) + .bearer_auth(&self.provider.config.api_key) .multipart(form); - if let Some(t) = self.timeout_seconds() { + if let Some(t) = self.provider.config.timeout_seconds { req = req.timeout(Duration::from_secs(t)); } @@ -564,21 +564,21 @@ impl TextToSpeechProvider for OpenAI { impl EmbeddingProvider for OpenAI { async fn embed(&self, input: Vec) -> Result>, LLMError> { let body = OpenAIEmbeddingRequest { - model: self.model().to_string(), + model: self.provider.config.model.to_string(), input, - encoding_format: self.provider.embedding_encoding_format.clone(), - dimensions: self.provider.embedding_dimensions, + encoding_format: self.provider.config.embedding_encoding_format.as_deref().map(|s| s.to_owned()), + dimensions: self.provider.config.embedding_dimensions, }; let url = self - .base_url() + .provider.config.base_url .join("embeddings") .map_err(|e| LLMError::HttpError(e.to_string()))?; let resp = self - .client() + .provider.client .post(url) - .bearer_auth(self.api_key()) + .bearer_auth(&self.provider.config.api_key) .json(&body) .send() 
.await? @@ -597,14 +597,14 @@ impl ModelsProvider for OpenAI { _request: Option<&ModelListRequest>, ) -> Result, LLMError> { let url = self - .base_url() + .provider.config.base_url .join("models") .map_err(|e| LLMError::HttpError(e.to_string()))?; let resp = self - .client() + .provider.client .get(url) - .bearer_auth(self.api_key()) + .bearer_auth(&self.provider.config.api_key) .send() .await? .error_for_status()?; @@ -622,27 +622,27 @@ impl LLMProvider for OpenAI {} // Helper methods to access provider fields impl OpenAI { pub fn api_key(&self) -> &str { - &self.provider.api_key + &self.provider.config.api_key } pub fn model(&self) -> &str { - &self.provider.model + &self.provider.config.model } pub fn base_url(&self) -> &reqwest::Url { - &self.provider.base_url + &self.provider.config.base_url } pub fn timeout_seconds(&self) -> Option { - self.provider.timeout_seconds + self.provider.config.timeout_seconds } - pub fn client(&self) -> &reqwest::Client { + pub fn get_client(&self) -> &reqwest::Client { &self.provider.client } pub fn tools(&self) -> Option<&[Tool]> { - self.provider.tools.as_deref() + self.provider.config.tools.as_deref() } /// Chat with OpenAI-hosted tools using the `/responses` endpoint @@ -661,26 +661,26 @@ impl OpenAI { hosted_tools: Vec, ) -> Result, LLMError> { let body = OpenAIAPIChatRequest { - model: self.provider.model.as_str(), + model: &self.provider.config.model, messages: Vec::new(), // Empty for hosted tools input: Some(input), max_completion_tokens: None, - max_output_tokens: self.provider.max_tokens, - temperature: self.provider.temperature, + max_output_tokens: self.provider.config.max_tokens, + temperature: self.provider.config.temperature, stream: false, - top_p: self.provider.top_p, - top_k: self.provider.top_k, + top_p: self.provider.config.top_p, + top_k: self.provider.config.top_k, tools: Some(hosted_tools), - tool_choice: self.provider.tool_choice.clone(), - reasoning_effort: self.provider.reasoning_effort.clone(), + 
tool_choice: self.provider.config.tool_choice.as_ref().cloned(), + reasoning_effort: self.provider.config.reasoning_effort.as_deref().map(|s| s.to_owned()), response_format: None, // Hosted tools don't use structured output stream_options: None, - extra_body: self.provider.extra_body.clone(), + extra_body: self.provider.config.extra_body.clone(), }; let url = self .provider - .base_url + .config.base_url .join("responses") // Use responses endpoint for hosted tools .map_err(|e| LLMError::HttpError(e.to_string()))?; @@ -688,7 +688,7 @@ impl OpenAI { .provider .client .post(url) - .bearer_auth(&self.provider.api_key) + .bearer_auth(&self.provider.config.api_key) .json(&body); if log::log_enabled!(log::Level::Trace) { @@ -697,7 +697,7 @@ impl OpenAI { } } - if let Some(timeout) = self.provider.timeout_seconds { + if let Some(timeout) = self.provider.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } diff --git a/src/backends/openrouter.rs b/src/backends/openrouter.rs index c0ca170..633400b 100644 --- a/src/backends/openrouter.rs +++ b/src/backends/openrouter.rs @@ -79,7 +79,7 @@ impl OpenRouter { impl LLMProvider for OpenRouter { fn tools(&self) -> Option<&[Tool]> { - self.tools.as_deref() + self.config.tools.as_deref() } } @@ -119,7 +119,7 @@ impl ModelsProvider for OpenRouter { &self, _request: Option<&ModelListRequest>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError( "Missing OpenRouter API key".to_string(), )); @@ -130,7 +130,7 @@ impl ModelsProvider for OpenRouter { let resp = self .client .get(&url) - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .send() .await? .error_for_status()?; diff --git a/src/backends/phind.rs b/src/backends/phind.rs index 4698bc8..8e18708 100644 --- a/src/backends/phind.rs +++ b/src/backends/phind.rs @@ -1,5 +1,7 @@ /// Implementation of the Phind LLM provider. 
/// This module provides integration with Phind's language model API. +use std::sync::Arc; + #[cfg(feature = "phind")] use crate::{ chat::{ChatMessage, ChatProvider, ChatRole}, @@ -21,26 +23,36 @@ use reqwest::StatusCode; use reqwest::{Client, Response}; use serde_json::{json, Value}; -/// Represents a Phind LLM client with configuration options. -pub struct Phind { - /// The model identifier to use (e.g. "Phind-70B") +/// Configuration for the Phind client. +#[derive(Debug)] +pub struct PhindConfig { + /// Model identifier. pub model: String, - /// Maximum number of tokens to generate + /// Maximum tokens to generate in responses. pub max_tokens: Option, - /// Temperature for controlling randomness (0.0-1.0) + /// Sampling temperature for response randomness. pub temperature: Option, - /// System prompt to prepend to conversations + /// System prompt to guide model behavior. pub system: Option, - /// Request timeout in seconds + /// Request timeout in seconds. pub timeout_seconds: Option, - /// Top-p sampling parameter + /// Top-p (nucleus) sampling parameter. pub top_p: Option, - /// Top-k sampling parameter + /// Top-k sampling parameter. pub top_k: Option, - /// Base URL for the Phind API + /// Base URL for API requests. pub api_base_url: String, - /// HTTP client for making requests - client: Client, +} + +/// Represents a Phind LLM client with configuration options. +/// +/// The client uses `Arc` internally for configuration, making cloning cheap. +#[derive(Debug, Clone)] +pub struct Phind { + /// Shared configuration wrapped in Arc for cheap cloning. + pub config: Arc, + /// HTTP client for making requests. 
+ pub client: Client, } #[derive(Debug)] @@ -80,16 +92,42 @@ impl Phind { if let Some(sec) = timeout_seconds { builder = builder.timeout(std::time::Duration::from_secs(sec)); } - Self { - model: model.unwrap_or_else(|| "Phind-70B".to_string()), + Self::with_client( + builder.build().expect("Failed to build reqwest Client"), + model, max_tokens, temperature, - system, timeout_seconds, + system, top_p, top_k, - api_base_url: "https://https.extension.phind.com/agent/".to_string(), - client: builder.build().expect("Failed to build reqwest Client"), + ) + } + + /// Creates a new Phind client with a custom HTTP client. + #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + ) -> Self { + Self { + config: Arc::new(PhindConfig { + model: model.unwrap_or_else(|| "Phind-70B".to_string()), + max_tokens, + temperature, + system, + timeout_seconds, + top_p, + top_k, + api_base_url: "https://https.extension.phind.com/agent/".to_string(), + }), + client, } } @@ -189,7 +227,7 @@ impl ChatProvider for Phind { })); } - if let Some(system_prompt) = &self.system { + if let Some(system_prompt) = &self.config.system { message_history.insert( 0, json!({ @@ -204,7 +242,7 @@ impl ChatProvider for Phind { "allow_magic_buttons": true, "is_vscode_extension": true, "message_history": message_history, - "requested_model": self.model, + "requested_model": self.config.model, "user_input": messages .iter() .rev() @@ -220,11 +258,11 @@ impl ChatProvider for Phind { let headers = Self::create_headers()?; let mut request = self .client - .post(&self.api_base_url) + .post(&self.config.api_base_url) .headers(headers) .json(&payload); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } diff --git a/src/backends/xai.rs 
b/src/backends/xai.rs index 1028829..5f95631 100644 --- a/src/backends/xai.rs +++ b/src/backends/xai.rs @@ -3,6 +3,8 @@ //! This module provides integration with X.AI's models through their API. //! It implements chat and completion capabilities using the X.AI API endpoints. +use std::sync::Arc; + use crate::ToolCall; #[cfg(feature = "xai")] use crate::{ @@ -23,47 +25,58 @@ use futures::stream::Stream; use reqwest::Client; use serde::{Deserialize, Serialize}; -/// Client for interacting with X.AI's API. -/// -/// This struct provides methods for making chat and completion requests to X.AI's language models. -/// It handles authentication, request configuration, and response parsing. -pub struct XAI { - /// API key for authentication with X.AI services +/// Configuration for the X.AI client. +/// Holds all configuration options, wrapped in an `Arc` for cheap cloning. +#[derive(Debug)] +pub struct XAIConfig { + /// API key for authentication with X.AI. pub api_key: String, - /// Model identifier to use for requests (e.g. "grok-2-latest") + /// Model identifier. pub model: String, - /// Maximum number of tokens to generate in responses + /// Maximum tokens to generate in responses. pub max_tokens: Option, - /// Temperature parameter for controlling response randomness (0.0 to 1.0) + /// Sampling temperature for response randomness. pub temperature: Option, - /// Optional system prompt to provide context + /// System prompt to guide model behavior. pub system: Option, - /// Request timeout duration in seconds + /// Request timeout in seconds. pub timeout_seconds: Option, - /// Top-p sampling parameter for controlling response diversity + /// Top-p (nucleus) sampling parameter. pub top_p: Option, - /// Top-k sampling parameter for controlling response diversity + /// Top-k sampling parameter. pub top_k: Option, - /// Embedding encoding format + /// Encoding format for embeddings. pub embedding_encoding_format: Option, - /// Embedding dimensions + /// Dimensions for embeddings. 
pub embedding_dimensions: Option, - /// JSON schema for structured output + /// JSON schema for structured output. pub json_schema: Option, - /// XAI search parameters + /// Search mode for web search functionality. pub xai_search_mode: Option, - /// XAI search sources + /// Source type for search. pub xai_search_source_type: Option, - /// XAI search excluded websites + /// Websites to exclude from search. pub xai_search_excluded_websites: Option>, - /// XAI search max results + /// Maximum number of search results. pub xai_search_max_results: Option, - /// XAI search from date + /// Start date for search results. pub xai_search_from_date: Option, - /// XAI search to date + /// End date for search results. pub xai_search_to_date: Option, - /// HTTP client for making API requests - client: Client, +} + +/// Client for interacting with X.AI's API. +/// +/// This struct provides methods for making chat and completion requests to X.AI's language models. +/// It handles authentication, request configuration, and response parsing. +/// +/// The client uses `Arc` internally for configuration, making cloning cheap. +#[derive(Debug, Clone)] +pub struct XAI { + /// Shared configuration wrapped in Arc for cheap cloning. + pub config: Arc, + /// HTTP client for making requests. + pub client: Client, } /// Search source configuration for search parameters @@ -236,24 +249,6 @@ struct XAIResponseFormat { impl XAI { /// Creates a new X.AI client with the specified configuration. 
- /// - /// # Arguments - /// - /// * `api_key` - Authentication key for X.AI API access - /// * `model` - Model identifier (defaults to "grok-2-latest" if None) - /// * `max_tokens` - Maximum number of tokens to generate in responses - /// * `temperature` - Sampling temperature for controlling randomness - /// * `timeout_seconds` - Request timeout duration in seconds - /// * `system` - System prompt for providing context - /// * `stream` - Whether to enable streaming responses - /// * `top_p` - Top-p sampling parameter - /// * `top_k` - Top-k sampling parameter - /// * `json_schema` - JSON schema for structured output - /// * `search_parameters` - Search parameters for search functionality - /// - /// # Returns - /// - /// A configured X.AI client instance ready to make API requests. #[allow(clippy::too_many_arguments)] pub fn new( api_key: impl Into, @@ -278,13 +273,14 @@ impl XAI { if let Some(sec) = timeout_seconds { builder = builder.timeout(std::time::Duration::from_secs(sec)); } - Self { - api_key: api_key.into(), - model: model.unwrap_or("grok-2-latest".to_string()), + Self::with_client( + builder.build().expect("Failed to build reqwest Client"), + api_key, + model, max_tokens, temperature, - system, timeout_seconds, + system, top_p, top_k, embedding_encoding_format, @@ -296,7 +292,52 @@ impl XAI { xai_search_max_results, xai_search_from_date, xai_search_to_date, - client: builder.build().expect("Failed to build reqwest Client"), + ) + } + + /// Creates a new X.AI client with a custom HTTP client. 
+ #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + api_key: impl Into, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + embedding_encoding_format: Option, + embedding_dimensions: Option, + json_schema: Option, + xai_search_mode: Option, + xai_search_source_type: Option, + xai_search_excluded_websites: Option>, + xai_search_max_results: Option, + xai_search_from_date: Option, + xai_search_to_date: Option, + ) -> Self { + Self { + config: Arc::new(XAIConfig { + api_key: api_key.into(), + model: model.unwrap_or("grok-2-latest".to_string()), + max_tokens, + temperature, + system, + timeout_seconds, + top_p, + top_k, + embedding_encoding_format, + embedding_dimensions, + json_schema, + xai_search_mode, + xai_search_source_type, + xai_search_excluded_websites, + xai_search_max_results, + xai_search_from_date, + xai_search_to_date, + }), + client, } } } @@ -313,7 +354,7 @@ impl ChatProvider for XAI { /// /// The generated response text, or an error if the request fails. async fn chat(&self, messages: &[ChatMessage]) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing X.AI API key".to_string())); } @@ -328,7 +369,7 @@ impl ChatProvider for XAI { }) .collect(); - if let Some(system) = &self.system { + if let Some(system) = &self.config.system { xai_msgs.insert( 0, XAIChatMessage { @@ -342,33 +383,34 @@ impl ChatProvider for XAI { // There's currently no check for these, so we'll leave it up to the user to provide a valid schema. // Unknown if XAI requires these too, but since it copies everything else from OpenAI, it's likely. 
let response_format: Option = - self.json_schema.as_ref().map(|s| XAIResponseFormat { + self.config.json_schema.as_ref().map(|s| XAIResponseFormat { response_type: XAIResponseType::JsonSchema, json_schema: Some(s.clone()), }); let search_parameters = XaiSearchParameters { - mode: self.xai_search_mode.clone(), + mode: self.config.xai_search_mode.clone(), sources: Some(vec![XaiSearchSource { source_type: self + .config .xai_search_source_type .clone() .unwrap_or("web".to_string()), - excluded_websites: self.xai_search_excluded_websites.clone(), + excluded_websites: self.config.xai_search_excluded_websites.clone(), }]), - max_search_results: self.xai_search_max_results, - from_date: self.xai_search_from_date.clone(), - to_date: self.xai_search_to_date.clone(), + max_search_results: self.config.xai_search_max_results, + from_date: self.config.xai_search_from_date.clone(), + to_date: self.config.xai_search_to_date.clone(), }; let body = XAIChatRequest { - model: &self.model, + model: &self.config.model, messages: xai_msgs, - max_tokens: self.max_tokens, - temperature: self.temperature, + max_tokens: self.config.max_tokens, + temperature: self.config.temperature, stream: false, - top_p: self.top_p, - top_k: self.top_k, + top_p: self.config.top_p, + top_k: self.config.top_k, response_format, search_parameters: Some(&search_parameters), }; @@ -382,10 +424,10 @@ impl ChatProvider for XAI { let mut request = self .client .post("https://api.x.ai/v1/chat/completions") - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .json(&body); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } @@ -432,7 +474,7 @@ impl ChatProvider for XAI { messages: &[ChatMessage], ) -> Result> + Send>>, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing X.AI API key".to_string())); } @@ -447,7 
+489,7 @@ impl ChatProvider for XAI { }) .collect(); - if let Some(system) = &self.system { + if let Some(system) = &self.config.system { xai_msgs.insert( 0, XAIChatMessage { @@ -458,13 +500,13 @@ impl ChatProvider for XAI { } let body = XAIChatRequest { - model: &self.model, + model: &self.config.model, messages: xai_msgs, - max_tokens: self.max_tokens, - temperature: self.temperature, + max_tokens: self.config.max_tokens, + temperature: self.config.temperature, stream: true, - top_p: self.top_p, - top_k: self.top_k, + top_p: self.config.top_p, + top_k: self.config.top_k, response_format: None, search_parameters: None, }; @@ -472,10 +514,10 @@ impl ChatProvider for XAI { let mut request = self .client .post("https://api.x.ai/v1/chat/completions") - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .json(&body); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } @@ -520,26 +562,27 @@ impl CompletionProvider for XAI { #[async_trait] impl EmbeddingProvider for XAI { async fn embed(&self, text: Vec) -> Result>, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing X.AI API key".into())); } let emb_format = self + .config .embedding_encoding_format .clone() .unwrap_or_else(|| "float".to_string()); let body = XAIEmbeddingRequest { - model: &self.model, + model: &self.config.model, input: text, encoding_format: Some(&emb_format), - dimensions: self.embedding_dimensions, + dimensions: self.config.embedding_dimensions, }; let resp = self .client .post("https://api.x.ai/v1/embeddings") - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .json(&body) .send() .await? 
@@ -570,16 +613,16 @@ impl ModelsProvider for XAI { &self, _request: Option<&ModelListRequest>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing X.AI API key".to_string())); } let mut request = self .client .get("https://api.x.ai/v1/models") - .bearer_auth(&self.api_key); + .bearer_auth(&self.config.api_key); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } diff --git a/src/providers/openai_compatible.rs b/src/providers/openai_compatible.rs index 8aac07e..02792d1 100644 --- a/src/providers/openai_compatible.rs +++ b/src/providers/openai_compatible.rs @@ -22,31 +22,60 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::marker::PhantomData; use std::pin::Pin; +use std::sync::Arc; -/// Generic OpenAI-compatible provider -/// -/// This struct provides a base implementation for any OpenAI-compatible API. -/// Different providers can customize behavior by implementing the `OpenAICompatibleConfig` trait. -pub struct OpenAICompatibleProvider { +/// Configuration for OpenAI-compatible providers. +#[derive(Debug)] +pub struct OpenAICompatibleProviderConfig { + /// API key for authentication. pub api_key: String, + /// Base URL for API requests. pub base_url: Url, + /// Model identifier. pub model: String, + /// Maximum tokens to generate in responses. pub max_tokens: Option, + /// Sampling temperature for response randomness. pub temperature: Option, + /// System prompt to guide model behavior. pub system: Option, + /// Request timeout in seconds. pub timeout_seconds: Option, + /// Top-p (nucleus) sampling parameter. pub top_p: Option, + /// Top-k sampling parameter. pub top_k: Option, + /// Available tools for the model to use. pub tools: Option>, + /// Tool choice configuration. 
pub tool_choice: Option, + /// Reasoning effort level for supported models. pub reasoning_effort: Option, + /// JSON schema for structured output. pub json_schema: Option, + /// Voice setting for TTS. pub voice: Option, + /// Extra body parameters for custom fields. pub extra_body: serde_json::Map, + /// Whether to enable parallel tool calls. pub parallel_tool_calls: bool, + /// Encoding format for embeddings. pub embedding_encoding_format: Option, + /// Dimensions for embeddings. pub embedding_dimensions: Option, + /// Whether to normalize streaming responses. pub normalize_response: bool, +} + +/// Generic OpenAI-compatible provider +/// +/// This struct provides a base implementation for any OpenAI-compatible API. +/// Different providers can customize behavior by implementing the `OpenAICompatibleConfig` trait. +#[derive(Debug, Clone)] +pub struct OpenAICompatibleProvider { + /// Shared configuration wrapped in Arc for cheap cloning. + pub config: Arc, + /// HTTP client for making requests. 
pub client: Client, _phantom: PhantomData, } @@ -321,11 +350,60 @@ impl OpenAICompatibleProvider { if let Some(sec) = timeout_seconds { builder = builder.timeout(std::time::Duration::from_secs(sec)); } + let client = builder.build().expect("Failed to build reqwest Client"); + Self::with_client( + client, + api_key, + base_url, + model, + max_tokens, + temperature, + timeout_seconds, + system, + top_p, + top_k, + tools, + tool_choice, + reasoning_effort, + json_schema, + voice, + extra_body, + parallel_tool_calls, + normalize_response, + embedding_encoding_format, + embedding_dimensions, + ) + } + + /// Creates a new provider with a custom HTTP client for connection pooling + #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + api_key: impl Into, + base_url: Option, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + tools: Option>, + tool_choice: Option, + reasoning_effort: Option, + json_schema: Option, + voice: Option, + extra_body: Option, + parallel_tool_calls: Option, + normalize_response: Option, + embedding_encoding_format: Option, + embedding_dimensions: Option, + ) -> Self { let extra_body = match extra_body { Some(serde_json::Value::Object(map)) => map, - _ => serde_json::Map::new(), // Should we panic here? 
+ _ => serde_json::Map::new(), }; - Self { + let config = OpenAICompatibleProviderConfig { api_key: api_key.into(), base_url: Url::parse(&format!("{}/", base_url.unwrap_or_else(|| T::DEFAULT_BASE_URL.to_owned()).trim_end_matches("/"))) .expect("Failed to parse base URL"), @@ -346,7 +424,10 @@ impl OpenAICompatibleProvider { normalize_response: normalize_response.unwrap_or(true), embedding_encoding_format, embedding_dimensions, - client: builder.build().expect("Failed to build reqwest Client"), + }; + Self { + config: Arc::new(config), + client, _phantom: PhantomData, } } @@ -372,7 +453,7 @@ impl OpenAICompatibleProvider { } }) .collect(); - if let Some(system) = &self.system { + if let Some(system) = &self.config.system { openai_msgs.insert( 0, OpenAIChatMessage { @@ -401,7 +482,7 @@ impl ChatProvider for OpenAICompatibleProvider { messages: &[ChatMessage], tools: Option<&[Tool]>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError(format!( "Missing {} API key", T::PROVIDER_NAME @@ -409,47 +490,47 @@ impl ChatProvider for OpenAICompatibleProvider { } let openai_msgs = self.prepare_messages(messages); let response_format: Option = if T::SUPPORTS_STRUCTURED_OUTPUT { - self.json_schema.clone().map(|s| s.into()) + self.config.json_schema.clone().map(|s| s.into()) } else { None }; - let request_tools = tools.map(|t| t.to_vec()).or_else(|| self.tools.clone()); + let request_tools = tools.map(|t| t.to_vec()).or_else(|| self.config.tools.clone()); let request_tool_choice = if request_tools.is_some() { - self.tool_choice.clone() + self.config.tool_choice.clone() } else { None }; let reasoning_effort = if T::SUPPORTS_REASONING_EFFORT { - self.reasoning_effort.clone() + self.config.reasoning_effort.clone() } else { None }; let parallel_tool_calls = if T::SUPPORTS_PARALLEL_TOOL_CALLS { - Some(self.parallel_tool_calls) + Some(self.config.parallel_tool_calls) } else { None }; let body = OpenAIChatRequest { 
- model: &self.model, + model: &self.config.model, messages: openai_msgs, - max_tokens: self.max_tokens, - temperature: self.temperature, + max_tokens: self.config.max_tokens, + temperature: self.config.temperature, stream: false, - top_p: self.top_p, - top_k: self.top_k, + top_p: self.config.top_p, + top_k: self.config.top_k, tools: request_tools, tool_choice: request_tool_choice, reasoning_effort, response_format, stream_options: None, parallel_tool_calls, - extra_body: self.extra_body.clone(), + extra_body: self.config.extra_body.clone(), }; let url = self - .base_url + .config.base_url .join(T::CHAT_ENDPOINT) .map_err(|e| LLMError::HttpError(e.to_string()))?; - let mut request = self.client.post(url).bearer_auth(&self.api_key).json(&body); + let mut request = self.client.post(url).bearer_auth(&self.config.api_key).json(&body); // Add custom headers if provider specifies them if let Some(headers) = T::custom_headers() { for (key, value) in headers { @@ -461,7 +542,7 @@ impl ChatProvider for OpenAICompatibleProvider { log::trace!("{} request payload: {}", T::PROVIDER_NAME, json); } } - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } let response = request.send().await?; @@ -524,7 +605,7 @@ impl ChatProvider for OpenAICompatibleProvider { std::pin::Pin> + Send>>, LLMError, > { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError(format!( "Missing {} API key", T::PROVIDER_NAME @@ -532,17 +613,17 @@ impl ChatProvider for OpenAICompatibleProvider { } let openai_msgs = self.prepare_messages(messages); let body = OpenAIChatRequest { - model: &self.model, + model: &self.config.model, messages: openai_msgs, - max_tokens: self.max_tokens, - temperature: self.temperature, + max_tokens: self.config.max_tokens, + temperature: self.config.temperature, stream: true, - top_p: self.top_p, - top_k: self.top_k, - 
tools: self.tools.clone(), - tool_choice: self.tool_choice.clone(), + top_p: self.config.top_p, + top_k: self.config.top_k, + tools: self.config.tools.clone(), + tool_choice: self.config.tool_choice.clone(), reasoning_effort: if T::SUPPORTS_REASONING_EFFORT { - self.reasoning_effort.clone() + self.config.reasoning_effort.clone() } else { None }, @@ -555,17 +636,17 @@ impl ChatProvider for OpenAICompatibleProvider { None }, parallel_tool_calls: if T::SUPPORTS_PARALLEL_TOOL_CALLS { - Some(self.parallel_tool_calls) + Some(self.config.parallel_tool_calls) } else { None }, - extra_body: self.extra_body.clone(), + extra_body: self.config.extra_body.clone(), }; let url = self - .base_url + .config.base_url .join(T::CHAT_ENDPOINT) .map_err(|e| LLMError::HttpError(e.to_string()))?; - let mut request = self.client.post(url).bearer_auth(&self.api_key).json(&body); + let mut request = self.client.post(url).bearer_auth(&self.config.api_key).json(&body); if let Some(headers) = T::custom_headers() { for (key, value) in headers { request = request.header(key, value); @@ -576,7 +657,7 @@ impl ChatProvider for OpenAICompatibleProvider { log::trace!("{} request payload: {}", T::PROVIDER_NAME, json); } } - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } let response = request.send().await?; @@ -589,7 +670,7 @@ impl ChatProvider for OpenAICompatibleProvider { raw_response: error_text, }); } - Ok(create_sse_stream(response, self.normalize_response)) + Ok(create_sse_stream(response, self.config.normalize_response)) } /// Sends a streaming chat request with tool support. 
@@ -612,7 +693,7 @@ impl ChatProvider for OpenAICompatibleProvider { tools: Option<&[Tool]>, ) -> Result> + Send>>, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError(format!( "Missing {} API key", T::PROVIDER_NAME @@ -622,20 +703,20 @@ impl ChatProvider for OpenAICompatibleProvider { let openai_msgs = self.prepare_messages(messages); // Use provided tools or fall back to configured tools - let effective_tools = tools.map(|t| t.to_vec()).or_else(|| self.tools.clone()); + let effective_tools = tools.map(|t| t.to_vec()).or_else(|| self.config.tools.clone()); let body = OpenAIChatRequest { - model: &self.model, + model: &self.config.model, messages: openai_msgs, - max_tokens: self.max_tokens, - temperature: self.temperature, + max_tokens: self.config.max_tokens, + temperature: self.config.temperature, stream: true, - top_p: self.top_p, - top_k: self.top_k, + top_p: self.config.top_p, + top_k: self.config.top_k, tools: effective_tools, - tool_choice: self.tool_choice.clone(), + tool_choice: self.config.tool_choice.clone(), reasoning_effort: if T::SUPPORTS_REASONING_EFFORT { - self.reasoning_effort.clone() + self.config.reasoning_effort.clone() } else { None }, @@ -648,19 +729,19 @@ impl ChatProvider for OpenAICompatibleProvider { None }, parallel_tool_calls: if T::SUPPORTS_PARALLEL_TOOL_CALLS { - Some(self.parallel_tool_calls) + Some(self.config.parallel_tool_calls) } else { None }, - extra_body: self.extra_body.clone(), + extra_body: self.config.extra_body.clone(), }; let url = self - .base_url + .config.base_url .join(T::CHAT_ENDPOINT) .map_err(|e| LLMError::HttpError(e.to_string()))?; - let mut request = self.client.post(url).bearer_auth(&self.api_key).json(&body); + let mut request = self.client.post(url).bearer_auth(&self.config.api_key).json(&body); if let Some(headers) = T::custom_headers() { for (key, value) in headers { @@ -678,7 +759,7 @@ impl ChatProvider for OpenAICompatibleProvider { } } - if let 
Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } From 00197aa37daa3032e20b8febdcbe57a6cfaedd34 Mon Sep 17 00:00:00 2001 From: aniketfuryrocks Date: Mon, 5 Jan 2026 13:12:14 +0530 Subject: [PATCH 2/2] add accessor methods to all backends for backward compatibility --- src/backends/anthropic.rs | 65 ++++++++++++++- src/backends/azure_openai.rs | 72 ++++++++++++++++- src/backends/cohere.rs | 9 ++- src/backends/deepseek.rs | 28 +++++++ src/backends/elevenlabs.rs | 36 ++++++++- src/backends/google.rs | 96 ++++++++++++++++------ src/backends/mistral.rs | 9 ++- src/backends/ollama.rs | 48 +++++++++++ src/backends/openai.rs | 89 +++++++++++++++------ src/backends/phind.rs | 36 +++++++++ src/backends/xai.rs | 48 +++++++++++ src/providers/openai_compatible.rs | 124 ++++++++++++++++++++++++++--- 12 files changed, 587 insertions(+), 73 deletions(-) diff --git a/src/backends/anthropic.rs b/src/backends/anthropic.rs index 80b704d..00482df 100644 --- a/src/backends/anthropic.rs +++ b/src/backends/anthropic.rs @@ -616,6 +616,57 @@ impl Anthropic { } } + pub fn api_key(&self) -> &str { + &self.config.api_key + } + + pub fn model(&self) -> &str { + &self.config.model + } + + pub fn max_tokens(&self) -> u32 { + self.config.max_tokens + } + + pub fn temperature(&self) -> f32 { + self.config.temperature + } + + pub fn timeout_seconds(&self) -> u64 { + self.config.timeout_seconds + } + + pub fn system(&self) -> &str { + &self.config.system + } + + pub fn top_p(&self) -> Option { + self.config.top_p + } + + pub fn top_k(&self) -> Option { + self.config.top_k + } + + pub fn tools(&self) -> Option<&[Tool]> { + self.config.tools.as_deref() + } + + pub fn tool_choice(&self) -> Option<&ToolChoice> { + self.config.tool_choice.as_ref() + } + + pub fn reasoning(&self) -> bool { + self.config.reasoning + } + + pub fn thinking_budget_tokens(&self) -> Option { + 
self.config.thinking_budget_tokens + } + + pub fn client(&self) -> &Client { + &self.client + } } #[async_trait] @@ -640,8 +691,11 @@ impl ChatProvider for Anthropic { } let anthropic_messages = Self::convert_messages_to_anthropic(messages); - let (anthropic_tools, final_tool_choice) = - Self::prepare_tools_and_choice(tools, self.config.tools.as_deref(), &self.config.tool_choice); + let (anthropic_tools, final_tool_choice) = Self::prepare_tools_and_choice( + tools, + self.config.tools.as_deref(), + &self.config.tool_choice, + ); let thinking = if self.config.reasoning { Some(ThinkingConfig { @@ -841,8 +895,11 @@ impl ChatProvider for Anthropic { } let anthropic_messages = Self::convert_messages_to_anthropic(messages); - let (anthropic_tools, final_tool_choice) = - Self::prepare_tools_and_choice(tools, self.config.tools.as_deref(), &self.config.tool_choice); + let (anthropic_tools, final_tool_choice) = Self::prepare_tools_and_choice( + tools, + self.config.tools.as_deref(), + &self.config.tool_choice, + ); let req_body = AnthropicCompleteRequest { messages: anthropic_messages, diff --git a/src/backends/azure_openai.rs b/src/backends/azure_openai.rs index 203b4eb..0b6d382 100644 --- a/src/backends/azure_openai.rs +++ b/src/backends/azure_openai.rs @@ -453,6 +453,74 @@ impl AzureOpenAI { client, } } + + pub fn api_key(&self) -> &str { + &self.config.api_key + } + + pub fn api_version(&self) -> &str { + &self.config.api_version + } + + pub fn base_url(&self) -> &Url { + &self.config.base_url + } + + pub fn model(&self) -> &str { + &self.config.model + } + + pub fn max_tokens(&self) -> Option { + self.config.max_tokens + } + + pub fn temperature(&self) -> Option { + self.config.temperature + } + + pub fn timeout_seconds(&self) -> Option { + self.config.timeout_seconds + } + + pub fn system(&self) -> Option<&str> { + self.config.system.as_deref() + } + + pub fn top_p(&self) -> Option { + self.config.top_p + } + + pub fn top_k(&self) -> Option { + self.config.top_k + } + 
+ pub fn tools(&self) -> Option<&[Tool]> { + self.config.tools.as_deref() + } + + pub fn tool_choice(&self) -> Option<&ToolChoice> { + self.config.tool_choice.as_ref() + } + + pub fn embedding_encoding_format(&self) -> Option<&str> { + self.config.embedding_encoding_format.as_deref() + } + + pub fn embedding_dimensions(&self) -> Option { + self.config.embedding_dimensions + } + + pub fn reasoning_effort(&self) -> Option<&str> { + self.config.reasoning_effort.as_deref() + } + + pub fn json_schema(&self) -> Option<&StructuredOutputFormat> { + self.config.json_schema.as_ref() + } + + pub fn client(&self) -> &Client { + &self.client + } } #[async_trait] @@ -519,7 +587,9 @@ impl ChatProvider for AzureOpenAI { let response_format: Option = self.config.json_schema.clone().map(|s| s.into()); - let request_tools = tools.map(|t| t.to_vec()).or_else(|| self.config.tools.clone()); + let request_tools = tools + .map(|t| t.to_vec()) + .or_else(|| self.config.tools.clone()); let request_tool_choice = if request_tools.is_some() { self.config.tool_choice.clone() } else { diff --git a/src/backends/cohere.rs b/src/backends/cohere.rs index 4a660e4..a8c11b9 100644 --- a/src/backends/cohere.rs +++ b/src/backends/cohere.rs @@ -143,12 +143,17 @@ impl EmbeddingProvider for Cohere { let body = CohereEmbeddingRequest { model: self.config.model.to_owned(), input, - encoding_format: self.config.embedding_encoding_format.as_deref().map(|s| s.to_owned()), + encoding_format: self + .config + .embedding_encoding_format + .as_deref() + .map(|s| s.to_owned()), dimensions: self.config.embedding_dimensions, }; let url = self - .config.base_url + .config + .base_url .join("embeddings") .map_err(|e| LLMError::HttpError(e.to_string()))?; diff --git a/src/backends/deepseek.rs b/src/backends/deepseek.rs index b9e9a66..824884b 100644 --- a/src/backends/deepseek.rs +++ b/src/backends/deepseek.rs @@ -148,6 +148,34 @@ impl DeepSeek { client, } } + + pub fn api_key(&self) -> &str { + &self.config.api_key + } + 
+ pub fn model(&self) -> &str { + &self.config.model + } + + pub fn max_tokens(&self) -> Option { + self.config.max_tokens + } + + pub fn temperature(&self) -> Option { + self.config.temperature + } + + pub fn timeout_seconds(&self) -> Option { + self.config.timeout_seconds + } + + pub fn system(&self) -> Option<&str> { + self.config.system.as_deref() + } + + pub fn client(&self) -> &Client { + &self.client + } } #[async_trait] diff --git a/src/backends/elevenlabs.rs b/src/backends/elevenlabs.rs index 1d39efa..5968859 100644 --- a/src/backends/elevenlabs.rs +++ b/src/backends/elevenlabs.rs @@ -105,7 +105,14 @@ impl ElevenLabs { timeout_seconds: Option, voice: Option, ) -> Self { - Self::with_client(Client::new(), api_key, model_id, base_url, timeout_seconds, voice) + Self::with_client( + Client::new(), + api_key, + model_id, + base_url, + timeout_seconds, + voice, + ) } /// Creates a new ElevenLabs instance with a custom HTTP client. @@ -128,6 +135,30 @@ impl ElevenLabs { client, } } + + pub fn api_key(&self) -> &str { + &self.config.api_key + } + + pub fn model_id(&self) -> &str { + &self.config.model_id + } + + pub fn base_url(&self) -> &str { + &self.config.base_url + } + + pub fn timeout_seconds(&self) -> Option { + self.config.timeout_seconds + } + + pub fn voice(&self) -> Option<&str> { + self.config.voice.as_deref() + } + + pub fn client(&self) -> &Client { + &self.client + } } #[async_trait] @@ -306,7 +337,8 @@ impl TextToSpeechProvider for ElevenLabs { let url = format!( "{}/text-to-speech/{}?output_format=mp3_44100_128", self.config.base_url, - self.config.voice + self.config + .voice .clone() .unwrap_or("JBFqnCBsd6RMkjVDRZzb".to_string()) ); diff --git a/src/backends/google.rs b/src/backends/google.rs index fb07e21..73fc0df 100644 --- a/src/backends/google.rs +++ b/src/backends/google.rs @@ -558,6 +558,50 @@ impl Google { client, } } + + pub fn api_key(&self) -> &str { + &self.config.api_key + } + + pub fn model(&self) -> &str { + &self.config.model + } 
+ + pub fn max_tokens(&self) -> Option { + self.config.max_tokens + } + + pub fn temperature(&self) -> Option { + self.config.temperature + } + + pub fn timeout_seconds(&self) -> Option { + self.config.timeout_seconds + } + + pub fn system(&self) -> Option<&str> { + self.config.system.as_deref() + } + + pub fn top_p(&self) -> Option { + self.config.top_p + } + + pub fn top_k(&self) -> Option { + self.config.top_k + } + + pub fn json_schema(&self) -> Option<&StructuredOutputFormat> { + self.config.json_schema.as_ref() + } + + pub fn tools(&self) -> Option<&[Tool]> { + self.config.tools.as_deref() + } + + pub fn client(&self) -> &Client { + &self.client + } } #[async_trait] @@ -655,21 +699,21 @@ impl ChatProvider for Google { } else { // If json_schema and json_schema.schema are not None, use json_schema.schema as the response schema and set response_mime_type to JSON // Google's API doesn't need the schema to have a "name" field, so we can just use the schema directly. - let (response_mime_type, response_schema) = if let Some(json_schema) = &self.config.json_schema - { - if let Some(schema) = &json_schema.schema { - // If the schema has an "additionalProperties" field (as required by OpenAI), remove it as Google's API doesn't support it - let mut schema = schema.clone(); - if let Some(obj) = schema.as_object_mut() { - obj.remove("additionalProperties"); + let (response_mime_type, response_schema) = + if let Some(json_schema) = &self.config.json_schema { + if let Some(schema) = &json_schema.schema { + // If the schema has an "additionalProperties" field (as required by OpenAI), remove it as Google's API doesn't support it + let mut schema = schema.clone(); + if let Some(obj) = schema.as_object_mut() { + obj.remove("additionalProperties"); + } + (Some(GoogleResponseMimeType::Json), Some(schema)) + } else { + (None, None) } - (Some(GoogleResponseMimeType::Json), Some(schema)) } else { (None, None) - } - } else { - (None, None) - }; + }; Some(GoogleGenerationConfig { 
max_output_tokens: self.config.max_tokens, temperature: self.config.temperature, @@ -823,23 +867,23 @@ impl ChatProvider for Google { let generation_config = { // If json_schema and json_schema.schema are not None, use json_schema.schema as the response schema and set response_mime_type to JSON // Google's API doesn't need the schema to have a "name" field, so we can just use the schema directly. - let (response_mime_type, response_schema) = if let Some(json_schema) = &self.config.json_schema - { - if let Some(schema) = &json_schema.schema { - // If the schema has an "additionalProperties" field (as required by OpenAI), remove it as Google's API doesn't support it - let mut schema = schema.clone(); - - if let Some(obj) = schema.as_object_mut() { - obj.remove("additionalProperties"); - } + let (response_mime_type, response_schema) = + if let Some(json_schema) = &self.config.json_schema { + if let Some(schema) = &json_schema.schema { + // If the schema has an "additionalProperties" field (as required by OpenAI), remove it as Google's API doesn't support it + let mut schema = schema.clone(); + + if let Some(obj) = schema.as_object_mut() { + obj.remove("additionalProperties"); + } - (Some(GoogleResponseMimeType::Json), Some(schema)) + (Some(GoogleResponseMimeType::Json), Some(schema)) + } else { + (None, None) + } } else { (None, None) - } - } else { - (None, None) - }; + }; Some(GoogleGenerationConfig { max_output_tokens: self.config.max_tokens, diff --git a/src/backends/mistral.rs b/src/backends/mistral.rs index 60615b3..862d9fb 100644 --- a/src/backends/mistral.rs +++ b/src/backends/mistral.rs @@ -143,12 +143,17 @@ impl EmbeddingProvider for Mistral { let body = MistralEmbeddingRequest { model: self.config.model.to_owned(), input, - encoding_format: self.config.embedding_encoding_format.as_deref().map(|s| s.to_owned()), + encoding_format: self + .config + .embedding_encoding_format + .as_deref() + .map(|s| s.to_owned()), dimensions: self.config.embedding_dimensions, 
}; let url = self - .config.base_url + .config + .base_url .join("embeddings") .map_err(|e| LLMError::HttpError(e.to_string()))?; diff --git a/src/backends/ollama.rs b/src/backends/ollama.rs index d83345e..fb38723 100644 --- a/src/backends/ollama.rs +++ b/src/backends/ollama.rs @@ -391,6 +391,54 @@ impl Ollama { } } + pub fn base_url(&self) -> &str { + &self.config.base_url + } + + pub fn api_key(&self) -> Option<&str> { + self.config.api_key.as_deref() + } + + pub fn model(&self) -> &str { + &self.config.model + } + + pub fn max_tokens(&self) -> Option { + self.config.max_tokens + } + + pub fn temperature(&self) -> Option { + self.config.temperature + } + + pub fn timeout_seconds(&self) -> Option { + self.config.timeout_seconds + } + + pub fn system(&self) -> Option<&str> { + self.config.system.as_deref() + } + + pub fn top_p(&self) -> Option { + self.config.top_p + } + + pub fn top_k(&self) -> Option { + self.config.top_k + } + + pub fn json_schema(&self) -> Option<&StructuredOutputFormat> { + self.config.json_schema.as_ref() + } + + pub fn tools(&self) -> Option<&[Tool]> { + self.config.tools.as_deref() + } + + pub fn client(&self) -> &Client { + &self.client + } + fn make_chat_request<'a>( &'a self, messages: &'a [ChatMessage], diff --git a/src/backends/openai.rs b/src/backends/openai.rs index 420ec41..8205dc6 100644 --- a/src/backends/openai.rs +++ b/src/backends/openai.rs @@ -286,8 +286,13 @@ impl ChatProvider for OpenAI { ) -> Result, LLMError> { // Use the common prepare_messages method from the OpenAI-compatible provider let openai_msgs = self.provider.prepare_messages(messages); - let response_format: Option = - self.provider.config.json_schema.as_ref().cloned().map(|s| s.into()); + let response_format: Option = self + .provider + .config + .json_schema + .as_ref() + .cloned() + .map(|s| s.into()); // Convert regular tools to OpenAI format let tool_calls = tools .map(|t| t.to_vec()) @@ -324,14 +329,20 @@ impl ChatProvider for OpenAI { top_k: 
self.provider.config.top_k, tools: final_tools, tool_choice: request_tool_choice, - reasoning_effort: self.provider.config.reasoning_effort.as_deref().map(|s| s.to_owned()), + reasoning_effort: self + .provider + .config + .reasoning_effort + .as_deref() + .map(|s| s.to_owned()), response_format, stream_options: None, extra_body: self.provider.config.extra_body.clone(), }; let url = self .provider - .config.base_url + .config + .base_url .join("chat/completions") .map_err(|e| LLMError::HttpError(e.to_string()))?; let mut request = self @@ -436,15 +447,16 @@ impl ChatProvider for OpenAI { > { let openai_msgs = self.provider.prepare_messages(messages); // Convert regular tools to OpenAI format for streaming - let openai_tools: Option> = self.provider.config.tools.as_deref().map(|tools| { - tools - .iter() - .map(|tool| OpenAITool::Function { - tool_type: tool.tool_type.clone(), - function: tool.function.clone(), - }) - .collect() - }); + let openai_tools: Option> = + self.provider.config.tools.as_deref().map(|tools| { + tools + .iter() + .map(|tool| OpenAITool::Function { + tool_type: tool.tool_type.clone(), + function: tool.function.clone(), + }) + .collect() + }); let body = OpenAIAPIChatRequest { model: &self.provider.config.model, messages: openai_msgs, @@ -457,7 +469,12 @@ impl ChatProvider for OpenAI { top_k: self.provider.config.top_k, tools: openai_tools, tool_choice: self.provider.config.tool_choice.as_ref().cloned(), - reasoning_effort: self.provider.config.reasoning_effort.as_deref().map(|s| s.to_owned()), + reasoning_effort: self + .provider + .config + .reasoning_effort + .as_deref() + .map(|s| s.to_owned()), response_format: None, stream_options: Some(OpenAIStreamOptions { include_usage: true, @@ -466,7 +483,8 @@ impl ChatProvider for OpenAI { }; let url = self .provider - .config.base_url + .config + .base_url .join("chat/completions") .map_err(|e| LLMError::HttpError(e.to_string()))?; let mut request = self @@ -523,7 +541,9 @@ impl 
SpeechToTextProvider for OpenAI { async fn transcribe_file(&self, file_path: &str) -> Result { let url = self - .provider.config.base_url + .provider + .config + .base_url .join("audio/transcriptions") .map_err(|e| LLMError::HttpError(e.to_string()))?; @@ -535,7 +555,8 @@ impl SpeechToTextProvider for OpenAI { .map_err(|e| LLMError::HttpError(e.to_string()))?; let mut req = self - .provider.client + .provider + .client .post(url) .bearer_auth(&self.provider.config.api_key) .multipart(form); @@ -566,17 +587,25 @@ impl EmbeddingProvider for OpenAI { let body = OpenAIEmbeddingRequest { model: self.provider.config.model.to_string(), input, - encoding_format: self.provider.config.embedding_encoding_format.as_deref().map(|s| s.to_owned()), + encoding_format: self + .provider + .config + .embedding_encoding_format + .as_deref() + .map(|s| s.to_owned()), dimensions: self.provider.config.embedding_dimensions, }; let url = self - .provider.config.base_url + .provider + .config + .base_url .join("embeddings") .map_err(|e| LLMError::HttpError(e.to_string()))?; let resp = self - .provider.client + .provider + .client .post(url) .bearer_auth(&self.provider.config.api_key) .json(&body) @@ -597,12 +626,15 @@ impl ModelsProvider for OpenAI { _request: Option<&ModelListRequest>, ) -> Result, LLMError> { let url = self - .provider.config.base_url + .provider + .config + .base_url .join("models") .map_err(|e| LLMError::HttpError(e.to_string()))?; let resp = self - .provider.client + .provider + .client .get(url) .bearer_auth(&self.provider.config.api_key) .send() @@ -619,7 +651,6 @@ impl ModelsProvider for OpenAI { impl LLMProvider for OpenAI {} -// Helper methods to access provider fields impl OpenAI { pub fn api_key(&self) -> &str { &self.provider.config.api_key @@ -637,7 +668,7 @@ impl OpenAI { self.provider.config.timeout_seconds } - pub fn get_client(&self) -> &reqwest::Client { + pub fn client(&self) -> &reqwest::Client { &self.provider.client } @@ -672,7 +703,12 @@ impl OpenAI 
{ top_k: self.provider.config.top_k, tools: Some(hosted_tools), tool_choice: self.provider.config.tool_choice.as_ref().cloned(), - reasoning_effort: self.provider.config.reasoning_effort.as_deref().map(|s| s.to_owned()), + reasoning_effort: self + .provider + .config + .reasoning_effort + .as_deref() + .map(|s| s.to_owned()), response_format: None, // Hosted tools don't use structured output stream_options: None, extra_body: self.provider.config.extra_body.clone(), @@ -680,7 +716,8 @@ impl OpenAI { let url = self .provider - .config.base_url + .config + .base_url .join("responses") // Use responses endpoint for hosted tools .map_err(|e| LLMError::HttpError(e.to_string()))?; diff --git a/src/backends/phind.rs b/src/backends/phind.rs index 8e18708..60ce9bd 100644 --- a/src/backends/phind.rs +++ b/src/backends/phind.rs @@ -131,6 +131,42 @@ impl Phind { } } + pub fn model(&self) -> &str { + &self.config.model + } + + pub fn max_tokens(&self) -> Option { + self.config.max_tokens + } + + pub fn temperature(&self) -> Option { + self.config.temperature + } + + pub fn timeout_seconds(&self) -> Option { + self.config.timeout_seconds + } + + pub fn system(&self) -> Option<&str> { + self.config.system.as_deref() + } + + pub fn top_p(&self) -> Option { + self.config.top_p + } + + pub fn top_k(&self) -> Option { + self.config.top_k + } + + pub fn api_base_url(&self) -> &str { + &self.config.api_base_url + } + + pub fn client(&self) -> &Client { + &self.client + } + /// Creates the required headers for API requests. 
fn create_headers() -> Result { let mut headers = HeaderMap::new(); diff --git a/src/backends/xai.rs b/src/backends/xai.rs index 5f95631..d0c5118 100644 --- a/src/backends/xai.rs +++ b/src/backends/xai.rs @@ -340,6 +340,54 @@ impl XAI { client, } } + + pub fn api_key(&self) -> &str { + &self.config.api_key + } + + pub fn model(&self) -> &str { + &self.config.model + } + + pub fn max_tokens(&self) -> Option { + self.config.max_tokens + } + + pub fn temperature(&self) -> Option { + self.config.temperature + } + + pub fn timeout_seconds(&self) -> Option { + self.config.timeout_seconds + } + + pub fn system(&self) -> Option<&str> { + self.config.system.as_deref() + } + + pub fn top_p(&self) -> Option { + self.config.top_p + } + + pub fn top_k(&self) -> Option { + self.config.top_k + } + + pub fn embedding_encoding_format(&self) -> Option<&str> { + self.config.embedding_encoding_format.as_deref() + } + + pub fn embedding_dimensions(&self) -> Option { + self.config.embedding_dimensions + } + + pub fn json_schema(&self) -> Option<&StructuredOutputFormat> { + self.config.json_schema.as_ref() + } + + pub fn client(&self) -> &Client { + &self.client + } } #[async_trait] diff --git a/src/providers/openai_compatible.rs b/src/providers/openai_compatible.rs index 02792d1..329cb92 100644 --- a/src/providers/openai_compatible.rs +++ b/src/providers/openai_compatible.rs @@ -405,8 +405,13 @@ impl OpenAICompatibleProvider { }; let config = OpenAICompatibleProviderConfig { api_key: api_key.into(), - base_url: Url::parse(&format!("{}/", base_url.unwrap_or_else(|| T::DEFAULT_BASE_URL.to_owned()).trim_end_matches("/"))) - .expect("Failed to parse base URL"), + base_url: Url::parse(&format!( + "{}/", + base_url + .unwrap_or_else(|| T::DEFAULT_BASE_URL.to_owned()) + .trim_end_matches("/") + )) + .expect("Failed to parse base URL"), model: model.unwrap_or_else(|| T::DEFAULT_MODEL.to_string()), max_tokens, temperature, @@ -432,6 +437,86 @@ impl OpenAICompatibleProvider { } } + pub fn 
api_key(&self) -> &str { + &self.config.api_key + } + + pub fn base_url(&self) -> &Url { + &self.config.base_url + } + + pub fn model(&self) -> &str { + &self.config.model + } + + pub fn max_tokens(&self) -> Option { + self.config.max_tokens + } + + pub fn temperature(&self) -> Option { + self.config.temperature + } + + pub fn system(&self) -> Option<&str> { + self.config.system.as_deref() + } + + pub fn timeout_seconds(&self) -> Option { + self.config.timeout_seconds + } + + pub fn top_p(&self) -> Option { + self.config.top_p + } + + pub fn top_k(&self) -> Option { + self.config.top_k + } + + pub fn tools(&self) -> Option<&[Tool]> { + self.config.tools.as_deref() + } + + pub fn tool_choice(&self) -> Option<&ToolChoice> { + self.config.tool_choice.as_ref() + } + + pub fn reasoning_effort(&self) -> Option<&str> { + self.config.reasoning_effort.as_deref() + } + + pub fn json_schema(&self) -> Option<&StructuredOutputFormat> { + self.config.json_schema.as_ref() + } + + pub fn voice(&self) -> Option<&str> { + self.config.voice.as_deref() + } + + pub fn extra_body(&self) -> &serde_json::Map { + &self.config.extra_body + } + + pub fn parallel_tool_calls(&self) -> bool { + self.config.parallel_tool_calls + } + + pub fn embedding_encoding_format(&self) -> Option<&str> { + self.config.embedding_encoding_format.as_deref() + } + + pub fn embedding_dimensions(&self) -> Option { + self.config.embedding_dimensions + } + + pub fn normalize_response(&self) -> bool { + self.config.normalize_response + } + + pub fn client(&self) -> &Client { + &self.client + } + pub fn prepare_messages(&self, messages: &[ChatMessage]) -> Vec> { let mut openai_msgs: Vec = messages .iter() @@ -494,7 +579,9 @@ impl ChatProvider for OpenAICompatibleProvider { } else { None }; - let request_tools = tools.map(|t| t.to_vec()).or_else(|| self.config.tools.clone()); + let request_tools = tools + .map(|t| t.to_vec()) + .or_else(|| self.config.tools.clone()); let request_tool_choice = if request_tools.is_some() 
{ self.config.tool_choice.clone() } else { @@ -527,10 +614,15 @@ impl ChatProvider for OpenAICompatibleProvider { extra_body: self.config.extra_body.clone(), }; let url = self - .config.base_url + .config + .base_url .join(T::CHAT_ENDPOINT) .map_err(|e| LLMError::HttpError(e.to_string()))?; - let mut request = self.client.post(url).bearer_auth(&self.config.api_key).json(&body); + let mut request = self + .client + .post(url) + .bearer_auth(&self.config.api_key) + .json(&body); // Add custom headers if provider specifies them if let Some(headers) = T::custom_headers() { for (key, value) in headers { @@ -643,10 +735,15 @@ impl ChatProvider for OpenAICompatibleProvider { extra_body: self.config.extra_body.clone(), }; let url = self - .config.base_url + .config + .base_url .join(T::CHAT_ENDPOINT) .map_err(|e| LLMError::HttpError(e.to_string()))?; - let mut request = self.client.post(url).bearer_auth(&self.config.api_key).json(&body); + let mut request = self + .client + .post(url) + .bearer_auth(&self.config.api_key) + .json(&body); if let Some(headers) = T::custom_headers() { for (key, value) in headers { request = request.header(key, value); @@ -703,7 +800,9 @@ impl ChatProvider for OpenAICompatibleProvider { let openai_msgs = self.prepare_messages(messages); // Use provided tools or fall back to configured tools - let effective_tools = tools.map(|t| t.to_vec()).or_else(|| self.config.tools.clone()); + let effective_tools = tools + .map(|t| t.to_vec()) + .or_else(|| self.config.tools.clone()); let body = OpenAIChatRequest { model: &self.config.model, @@ -737,11 +836,16 @@ impl ChatProvider for OpenAICompatibleProvider { }; let url = self - .config.base_url + .config + .base_url .join(T::CHAT_ENDPOINT) .map_err(|e| LLMError::HttpError(e.to_string()))?; - let mut request = self.client.post(url).bearer_auth(&self.config.api_key).json(&body); + let mut request = self + .client + .post(url) + .bearer_auth(&self.config.api_key) + .json(&body); if let Some(headers) = 
T::custom_headers() { for (key, value) in headers {