diff --git a/src/backends/anthropic.rs b/src/backends/anthropic.rs index 89fb600..00482df 100644 --- a/src/backends/anthropic.rs +++ b/src/backends/anthropic.rs @@ -3,6 +3,7 @@ //! This module provides integration with Anthropic's Claude models through their API. use std::collections::HashMap; +use std::sync::Arc; use crate::{ builder::LLMBackend, @@ -26,24 +27,51 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::Value; -/// Client for interacting with Anthropic's API. +/// Configuration for the Anthropic client. /// -/// Provides methods for chat and completion requests using Anthropic's models. +/// This struct holds all configuration options and is wrapped in an `Arc` +/// to enable cheap cloning of the `Anthropic` client. #[derive(Debug)] -pub struct Anthropic { +pub struct AnthropicConfig { + /// API key for authentication with Anthropic. pub api_key: String, + /// Model identifier (e.g., "claude-3-sonnet-20240229"). pub model: String, + /// Maximum tokens to generate in responses. pub max_tokens: u32, + /// Sampling temperature for response randomness. pub temperature: f32, + /// Request timeout in seconds. pub timeout_seconds: u64, + /// System prompt to guide model behavior. pub system: String, + /// Top-p (nucleus) sampling parameter. pub top_p: Option, + /// Top-k sampling parameter. pub top_k: Option, + /// Available tools for the model to use. pub tools: Option>, + /// Tool choice configuration. pub tool_choice: Option, + /// Whether extended thinking is enabled. pub reasoning: bool, + /// Budget tokens for extended thinking. pub thinking_budget_tokens: Option, - client: Client, +} + +/// Client for interacting with Anthropic's API. +/// +/// Provides methods for chat and completion requests using Anthropic's models. +/// +/// The client uses `Arc` internally for configuration, making cloning cheap +/// (only an atomic reference count increment). 
This allows sharing a single +/// client across multiple tasks without expensive deep copies. +#[derive(Debug, Clone)] +pub struct Anthropic { + /// Shared configuration wrapped in Arc for cheap cloning. + pub config: Arc, + /// HTTP client for making requests. + pub client: Client, } /// Anthropic-specific tool format that matches their API structure @@ -469,7 +497,6 @@ impl Anthropic { /// * `temperature` - Sampling temperature (defaults to 0.7) /// * `timeout_seconds` - Request timeout in seconds (defaults to 30) /// * `system` - System prompt (defaults to "You are a helpful assistant.") - /// * /// * `thinking_budget_tokens` - Budget tokens for thinking (optional) #[allow(clippy::too_many_arguments)] pub fn new( @@ -486,26 +513,160 @@ impl Anthropic { reasoning: Option, thinking_budget_tokens: Option, ) -> Self { + let timeout = timeout_seconds.unwrap_or(30); let mut builder = Client::builder(); - if let Some(sec) = timeout_seconds { - builder = builder.timeout(std::time::Duration::from_secs(sec)); + if timeout > 0 { + builder = builder.timeout(std::time::Duration::from_secs(timeout)); } - Self { - api_key: api_key.into(), - model: model.unwrap_or_else(|| "claude-3-sonnet-20240229".to_string()), - max_tokens: max_tokens.unwrap_or(300), - temperature: temperature.unwrap_or(0.7), - system: system.unwrap_or_else(|| "You are a helpful assistant.".to_string()), - timeout_seconds: timeout_seconds.unwrap_or(30), + Self::with_client( + builder.build().expect("Failed to build reqwest Client"), + api_key, + model, + max_tokens, + temperature, + timeout_seconds, + system, top_p, top_k, tools, tool_choice, - reasoning: reasoning.unwrap_or(false), + reasoning, thinking_budget_tokens, - client: builder.build().expect("Failed to build reqwest Client"), + ) + } + + /// Creates a new Anthropic client with a custom HTTP client. 
+ /// + /// This constructor allows sharing a pre-configured `reqwest::Client` across + /// multiple provider instances, enabling connection pooling and custom + /// HTTP settings. + /// + /// # Arguments + /// + /// * `client` - A pre-configured `reqwest::Client` for HTTP requests + /// * `api_key` - Anthropic API key for authentication + /// * `model` - Model identifier (defaults to "claude-3-sonnet-20240229") + /// * `max_tokens` - Maximum tokens in response (defaults to 300) + /// * `temperature` - Sampling temperature (defaults to 0.7) + /// * `timeout_seconds` - Request timeout in seconds (defaults to 30) + /// * `system` - System prompt (defaults to "You are a helpful assistant.") + /// * `thinking_budget_tokens` - Budget tokens for thinking (optional) + /// + /// # Examples + /// + /// ```rust + /// use reqwest::Client; + /// use std::time::Duration; + /// + /// // Create a shared client with custom settings + /// let shared_client = Client::builder() + /// .timeout(Duration::from_secs(120)) + /// .build() + /// .unwrap(); + /// + /// // Use the shared client for multiple Anthropic instances + /// let anthropic = llm::backends::anthropic::Anthropic::with_client( + /// shared_client.clone(), + /// "your-api-key", + /// Some("claude-3-opus-20240229".to_string()), + /// Some(1000), + /// Some(0.7), + /// Some(120), + /// Some("You are a helpful assistant.".to_string()), + /// None, + /// None, + /// None, + /// None, + /// None, + /// None, + /// ); + /// ``` + #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + api_key: impl Into, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + tools: Option>, + tool_choice: Option, + reasoning: Option, + thinking_budget_tokens: Option, + ) -> Self { + Self { + config: Arc::new(AnthropicConfig { + api_key: api_key.into(), + model: model.unwrap_or_else(|| "claude-3-sonnet-20240229".to_string()), + 
max_tokens: max_tokens.unwrap_or(300), + temperature: temperature.unwrap_or(0.7), + system: system.unwrap_or_else(|| "You are a helpful assistant.".to_string()), + timeout_seconds: timeout_seconds.unwrap_or(30), + top_p, + top_k, + tools, + tool_choice, + reasoning: reasoning.unwrap_or(false), + thinking_budget_tokens, + }), + client, } } + + pub fn api_key(&self) -> &str { + &self.config.api_key + } + + pub fn model(&self) -> &str { + &self.config.model + } + + pub fn max_tokens(&self) -> u32 { + self.config.max_tokens + } + + pub fn temperature(&self) -> f32 { + self.config.temperature + } + + pub fn timeout_seconds(&self) -> u64 { + self.config.timeout_seconds + } + + pub fn system(&self) -> &str { + &self.config.system + } + + pub fn top_p(&self) -> Option { + self.config.top_p + } + + pub fn top_k(&self) -> Option { + self.config.top_k + } + + pub fn tools(&self) -> Option<&[Tool]> { + self.config.tools.as_deref() + } + + pub fn tool_choice(&self) -> Option<&ToolChoice> { + self.config.tool_choice.as_ref() + } + + pub fn reasoning(&self) -> bool { + self.config.reasoning + } + + pub fn thinking_budget_tokens(&self) -> Option { + self.config.thinking_budget_tokens + } + + pub fn client(&self) -> &Client { + &self.client + } } #[async_trait] @@ -525,18 +686,21 @@ impl ChatProvider for Anthropic { messages: &[ChatMessage], tools: Option<&[Tool]>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Anthropic API key".to_string())); } let anthropic_messages = Self::convert_messages_to_anthropic(messages); - let (anthropic_tools, final_tool_choice) = - Self::prepare_tools_and_choice(tools, self.tools.as_deref(), &self.tool_choice); + let (anthropic_tools, final_tool_choice) = Self::prepare_tools_and_choice( + tools, + self.config.tools.as_deref(), + &self.config.tool_choice, + ); - let thinking = if self.reasoning { + let thinking = if self.config.reasoning { Some(ThinkingConfig { 
thinking_type: "enabled".to_string(), - budget_tokens: self.thinking_budget_tokens.unwrap_or(16000), + budget_tokens: self.config.thinking_budget_tokens.unwrap_or(16000), }) } else { None @@ -544,13 +708,13 @@ impl ChatProvider for Anthropic { let req_body = AnthropicCompleteRequest { messages: anthropic_messages, - model: &self.model, - max_tokens: Some(self.max_tokens), - temperature: Some(self.temperature), - system: Some(&self.system), + model: &self.config.model, + max_tokens: Some(self.config.max_tokens), + temperature: Some(self.config.temperature), + system: Some(&self.config.system), stream: Some(false), - top_p: self.top_p, - top_k: self.top_k, + top_p: self.config.top_p, + top_k: self.config.top_k, tools: anthropic_tools, tool_choice: final_tool_choice, thinking, @@ -559,13 +723,13 @@ impl ChatProvider for Anthropic { let mut request = self .client .post("https://api.anthropic.com/v1/messages") - .header("x-api-key", &self.api_key) + .header("x-api-key", &self.config.api_key) .header("Content-Type", "application/json") .header("anthropic-version", "2023-06-01") .json(&req_body); - if self.timeout_seconds > 0 { - request = request.timeout(std::time::Duration::from_secs(self.timeout_seconds)); + if self.config.timeout_seconds > 0 { + request = request.timeout(std::time::Duration::from_secs(self.config.timeout_seconds)); } if log::log_enabled!(log::Level::Trace) { @@ -613,7 +777,7 @@ impl ChatProvider for Anthropic { messages: &[ChatMessage], ) -> Result> + Send>>, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Anthropic API key".to_string())); } @@ -671,13 +835,13 @@ impl ChatProvider for Anthropic { let req_body = AnthropicCompleteRequest { messages: anthropic_messages, - model: &self.model, - max_tokens: Some(self.max_tokens), - temperature: Some(self.temperature), - system: Some(&self.system), + model: &self.config.model, + max_tokens: Some(self.config.max_tokens), + temperature: 
Some(self.config.temperature), + system: Some(&self.config.system), stream: Some(true), - top_p: self.top_p, - top_k: self.top_k, + top_p: self.config.top_p, + top_k: self.config.top_k, tools: None, tool_choice: None, thinking: None, @@ -686,13 +850,13 @@ impl ChatProvider for Anthropic { let mut request = self .client .post("https://api.anthropic.com/v1/messages") - .header("x-api-key", &self.api_key) + .header("x-api-key", &self.config.api_key) .header("Content-Type", "application/json") .header("anthropic-version", "2023-06-01") .json(&req_body); - if self.timeout_seconds > 0 { - request = request.timeout(std::time::Duration::from_secs(self.timeout_seconds)); + if self.config.timeout_seconds > 0 { + request = request.timeout(std::time::Duration::from_secs(self.config.timeout_seconds)); } let response = request.send().await?; @@ -726,23 +890,26 @@ impl ChatProvider for Anthropic { tools: Option<&[Tool]>, ) -> Result> + Send>>, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Anthropic API key".to_string())); } let anthropic_messages = Self::convert_messages_to_anthropic(messages); - let (anthropic_tools, final_tool_choice) = - Self::prepare_tools_and_choice(tools, self.tools.as_deref(), &self.tool_choice); + let (anthropic_tools, final_tool_choice) = Self::prepare_tools_and_choice( + tools, + self.config.tools.as_deref(), + &self.config.tool_choice, + ); let req_body = AnthropicCompleteRequest { messages: anthropic_messages, - model: &self.model, - max_tokens: Some(self.max_tokens), - temperature: Some(self.temperature), - system: Some(&self.system), + model: &self.config.model, + max_tokens: Some(self.config.max_tokens), + temperature: Some(self.config.temperature), + system: Some(&self.config.system), stream: Some(true), - top_p: self.top_p, - top_k: self.top_k, + top_p: self.config.top_p, + top_k: self.config.top_k, tools: anthropic_tools, tool_choice: final_tool_choice, thinking: None, // 
Thinking not supported with streaming tools @@ -751,13 +918,13 @@ impl ChatProvider for Anthropic { let mut request = self .client .post("https://api.anthropic.com/v1/messages") - .header("x-api-key", &self.api_key) + .header("x-api-key", &self.config.api_key) .header("Content-Type", "application/json") .header("anthropic-version", "2023-06-01") .json(&req_body); - if self.timeout_seconds > 0 { - request = request.timeout(std::time::Duration::from_secs(self.timeout_seconds)); + if self.config.timeout_seconds > 0 { + request = request.timeout(std::time::Duration::from_secs(self.config.timeout_seconds)); } if log::log_enabled!(log::Level::Trace) { @@ -924,7 +1091,7 @@ impl ModelsProvider for Anthropic { let resp = self .client .get("https://api.anthropic.com/v1/models") - .header("x-api-key", &self.api_key) + .header("x-api-key", &self.config.api_key) .header("Content-Type", "application/json") .header("anthropic-version", "2023-06-01") .send() @@ -938,7 +1105,7 @@ impl ModelsProvider for Anthropic { impl crate::LLMProvider for Anthropic { fn tools(&self) -> Option<&[Tool]> { - self.tools.as_deref() + self.config.tools.as_deref() } } diff --git a/src/backends/azure_openai.rs b/src/backends/azure_openai.rs index 562f755..0b6d382 100644 --- a/src/backends/azure_openai.rs +++ b/src/backends/azure_openai.rs @@ -2,6 +2,8 @@ //! //! This module provides integration with Azure OpenAI's GPT models through their API. +use std::sync::Arc; + #[cfg(feature = "azure_openai")] use crate::{ builder::LLMBackend, @@ -24,29 +26,54 @@ use either::*; use reqwest::{Client, Url}; use serde::{Deserialize, Serialize}; -/// Client for interacting with Azure OpenAI's API. -/// -/// Provides methods for chat and completion requests using Azure OpenAI's models. -pub struct AzureOpenAI { +/// Configuration for the Azure OpenAI client. +#[derive(Debug)] +pub struct AzureOpenAIConfig { + /// API key for authentication. pub api_key: String, + /// API version string. 
pub api_version: String, + /// Base URL for API requests. pub base_url: Url, + /// Model identifier. pub model: String, + /// Maximum tokens to generate in responses. pub max_tokens: Option, + /// Sampling temperature for response randomness. pub temperature: Option, + /// System prompt to guide model behavior. pub system: Option, + /// Request timeout in seconds. pub timeout_seconds: Option, + /// Top-p (nucleus) sampling parameter. pub top_p: Option, + /// Top-k sampling parameter. pub top_k: Option, + /// Available tools for the model to use. pub tools: Option>, + /// Tool choice configuration. pub tool_choice: Option, - /// Embedding parameters + /// Encoding format for embeddings. pub embedding_encoding_format: Option, + /// Dimensions for embeddings. pub embedding_dimensions: Option, + /// Reasoning effort level. pub reasoning_effort: Option, - /// JSON schema for structured output + /// JSON schema for structured output. pub json_schema: Option, - client: Client, +} + +/// Client for interacting with Azure OpenAI's API. +/// +/// Provides methods for chat and completion requests using Azure OpenAI's models. +/// +/// The client uses `Arc` internally for configuration, making cloning cheap. +#[derive(Debug, Clone)] +pub struct AzureOpenAI { + /// Shared configuration wrapped in Arc for cheap cloning. + pub config: Arc, + /// HTTP client for making requests. + pub client: Client, } /// Individual message in an OpenAI chat conversation. 
@@ -356,31 +383,144 @@ impl AzureOpenAI { if let Some(sec) = timeout_seconds { builder = builder.timeout(std::time::Duration::from_secs(sec)); } - - let endpoint = endpoint.into(); - let deployment_id = deployment_id.into(); - - Self { - api_key: api_key.into(), - api_version: api_version.into(), - base_url: Url::parse(&format!("{endpoint}/openai/deployments/{deployment_id}/")) - .expect("Failed to parse base Url"), - model: model.unwrap_or("gpt-3.5-turbo".to_string()), + Self::with_client( + builder.build().expect("Failed to build reqwest Client"), + api_key, + api_version, + deployment_id, + endpoint, + model, max_tokens, temperature, - system, timeout_seconds, + system, top_p, top_k, - tools, - tool_choice, embedding_encoding_format, embedding_dimensions, - client: builder.build().expect("Failed to build reqwest Client"), + tools, + tool_choice, reasoning_effort, json_schema, + ) + } + + /// Creates a new Azure OpenAI client with a custom HTTP client. + #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + api_key: impl Into, + api_version: impl Into, + deployment_id: impl Into, + endpoint: impl Into, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + embedding_encoding_format: Option, + embedding_dimensions: Option, + tools: Option>, + tool_choice: Option, + reasoning_effort: Option, + json_schema: Option, + ) -> Self { + let endpoint = endpoint.into(); + let deployment_id = deployment_id.into(); + + Self { + config: Arc::new(AzureOpenAIConfig { + api_key: api_key.into(), + api_version: api_version.into(), + base_url: Url::parse(&format!("{endpoint}/openai/deployments/{deployment_id}/")) + .expect("Failed to parse base Url"), + model: model.unwrap_or("gpt-3.5-turbo".to_string()), + max_tokens, + temperature, + system, + timeout_seconds, + top_p, + top_k, + tools, + tool_choice, + embedding_encoding_format, + embedding_dimensions, + 
reasoning_effort, + json_schema, + }), + client, } } + + pub fn api_key(&self) -> &str { + &self.config.api_key + } + + pub fn api_version(&self) -> &str { + &self.config.api_version + } + + pub fn base_url(&self) -> &Url { + &self.config.base_url + } + + pub fn model(&self) -> &str { + &self.config.model + } + + pub fn max_tokens(&self) -> Option { + self.config.max_tokens + } + + pub fn temperature(&self) -> Option { + self.config.temperature + } + + pub fn timeout_seconds(&self) -> Option { + self.config.timeout_seconds + } + + pub fn system(&self) -> Option<&str> { + self.config.system.as_deref() + } + + pub fn top_p(&self) -> Option { + self.config.top_p + } + + pub fn top_k(&self) -> Option { + self.config.top_k + } + + pub fn tools(&self) -> Option<&[Tool]> { + self.config.tools.as_deref() + } + + pub fn tool_choice(&self) -> Option<&ToolChoice> { + self.config.tool_choice.as_ref() + } + + pub fn embedding_encoding_format(&self) -> Option<&str> { + self.config.embedding_encoding_format.as_deref() + } + + pub fn embedding_dimensions(&self) -> Option { + self.config.embedding_dimensions + } + + pub fn reasoning_effort(&self) -> Option<&str> { + self.config.reasoning_effort.as_deref() + } + + pub fn json_schema(&self) -> Option<&StructuredOutputFormat> { + self.config.json_schema.as_ref() + } + + pub fn client(&self) -> &Client { + &self.client + } } #[async_trait] @@ -399,7 +539,7 @@ impl ChatProvider for AzureOpenAI { messages: &[ChatMessage], tools: Option<&[Tool]>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError( "Missing Azure OpenAI API key".to_string(), )); @@ -425,7 +565,7 @@ impl ChatProvider for AzureOpenAI { } } - if let Some(system) = &self.system { + if let Some(system) = &self.config.system { openai_msgs.insert( 0, AzureOpenAIChatMessage { @@ -445,26 +585,28 @@ impl ChatProvider for AzureOpenAI { // Build the response format object let response_format: Option = - 
self.json_schema.clone().map(|s| s.into()); + self.config.json_schema.clone().map(|s| s.into()); - let request_tools = tools.map(|t| t.to_vec()).or_else(|| self.tools.clone()); + let request_tools = tools + .map(|t| t.to_vec()) + .or_else(|| self.config.tools.clone()); let request_tool_choice = if request_tools.is_some() { - self.tool_choice.clone() + self.config.tool_choice.clone() } else { None }; let body = AzureOpenAIChatRequest { - model: &self.model, + model: &self.config.model, messages: openai_msgs, - max_tokens: self.max_tokens, - temperature: self.temperature, + max_tokens: self.config.max_tokens, + temperature: self.config.temperature, stream: false, - top_p: self.top_p, - top_k: self.top_k, + top_p: self.config.top_p, + top_k: self.config.top_k, tools: request_tools, tool_choice: request_tool_choice, - reasoning_effort: self.reasoning_effort.clone(), + reasoning_effort: self.config.reasoning_effort.clone(), response_format, }; @@ -475,20 +617,21 @@ impl ChatProvider for AzureOpenAI { } let mut url = self + .config .base_url .join("chat/completions") .map_err(|e| LLMError::HttpError(e.to_string()))?; url.query_pairs_mut() - .append_pair("api-version", &self.api_version); + .append_pair("api-version", &self.config.api_version); let mut request = self .client .post(url) - .header("api-key", &self.api_key) + .header("api-key", &self.config.api_key) .json(&body); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } @@ -542,34 +685,36 @@ impl CompletionProvider for AzureOpenAI { #[async_trait] impl EmbeddingProvider for AzureOpenAI { async fn embed(&self, input: Vec) -> Result>, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing OpenAI API key".into())); } let emb_format = self + .config .embedding_encoding_format .clone() .unwrap_or_else(|| "float".to_string()); let body 
= OpenAIEmbeddingRequest { - model: self.model.clone(), + model: self.config.model.clone(), input, encoding_format: Some(emb_format), - dimensions: self.embedding_dimensions, + dimensions: self.config.embedding_dimensions, }; let mut url = self + .config .base_url .join("embeddings") .map_err(|e| LLMError::HttpError(e.to_string()))?; url.query_pairs_mut() - .append_pair("api-version", &self.api_version); + .append_pair("api-version", &self.config.api_version); let resp = self .client .post(url) - .header("api-key", &self.api_key) + .header("api-key", &self.config.api_key) .json(&body) .send() .await? @@ -584,7 +729,7 @@ impl EmbeddingProvider for AzureOpenAI { impl LLMProvider for AzureOpenAI { fn tools(&self) -> Option<&[Tool]> { - self.tools.as_deref() + self.config.tools.as_deref() } } @@ -612,23 +757,24 @@ impl ModelsProvider for AzureOpenAI { &self, _request: Option<&ModelListRequest>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError( "Missing Azure OpenAI API key".to_string(), )); } let mut url = self + .config .base_url .join("models") .map_err(|e| LLMError::HttpError(e.to_string()))?; url.query_pairs_mut() - .append_pair("api-version", &self.api_version); + .append_pair("api-version", &self.config.api_version); - let mut request = self.client.get(url).header("api-key", &self.api_key); + let mut request = self.client.get(url).header("api-key", &self.config.api_key); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } diff --git a/src/backends/cohere.rs b/src/backends/cohere.rs index a835556..a8c11b9 100644 --- a/src/backends/cohere.rs +++ b/src/backends/cohere.rs @@ -104,7 +104,7 @@ struct CohereEmbeddingResponse { impl LLMProvider for Cohere { fn tools(&self) -> Option<&[Tool]> { - self.tools.as_deref() + self.config.tools.as_deref() } } @@ -136,18 +136,23 @@ impl 
SpeechToTextProvider for Cohere { #[async_trait] impl EmbeddingProvider for Cohere { async fn embed(&self, input: Vec) -> Result>, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Cohere API key".into())); } let body = CohereEmbeddingRequest { - model: self.model.clone(), + model: self.config.model.to_owned(), input, - encoding_format: self.embedding_encoding_format.clone(), - dimensions: self.embedding_dimensions, + encoding_format: self + .config + .embedding_encoding_format + .as_deref() + .map(|s| s.to_owned()), + dimensions: self.config.embedding_dimensions, }; let url = self + .config .base_url .join("embeddings") .map_err(|e| LLMError::HttpError(e.to_string()))?; @@ -155,7 +160,7 @@ impl EmbeddingProvider for Cohere { let resp = self .client .post(url) - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .json(&body) .send() .await? diff --git a/src/backends/deepseek.rs b/src/backends/deepseek.rs index 7ef83bf..824884b 100644 --- a/src/backends/deepseek.rs +++ b/src/backends/deepseek.rs @@ -2,6 +2,8 @@ //! //! This module provides integration with DeepSeek's models through their API. +use std::sync::Arc; + use crate::chat::{ChatResponse, Tool}; #[cfg(feature = "deepseek")] use crate::{ @@ -21,14 +23,32 @@ use serde::{Deserialize, Serialize}; use crate::ToolCall; -pub struct DeepSeek { +/// Configuration for the DeepSeek client. +#[derive(Debug)] +pub struct DeepSeekConfig { + /// API key for authentication with DeepSeek. pub api_key: String, + /// Model identifier. pub model: String, + /// Maximum tokens to generate in responses. pub max_tokens: Option, + /// Sampling temperature for response randomness. pub temperature: Option, + /// System prompt to guide model behavior. pub system: Option, + /// Request timeout in seconds. pub timeout_seconds: Option, - client: Client, +} + +/// Client for interacting with DeepSeek's API. 
+/// +/// The client uses `Arc` internally for configuration, making cloning cheap. +#[derive(Debug, Clone)] +pub struct DeepSeek { + /// Shared configuration wrapped in Arc for cheap cloning. + pub config: Arc, + /// HTTP client for making requests. + pub client: Client, } #[derive(Serialize)] @@ -95,16 +115,67 @@ impl DeepSeek { if let Some(sec) = timeout_seconds { builder = builder.timeout(std::time::Duration::from_secs(sec)); } - Self { - api_key: api_key.into(), - model: model.unwrap_or("deepseek-chat".to_string()), + Self::with_client( + builder.build().expect("Failed to build reqwest Client"), + api_key, + model, max_tokens, temperature, - system, timeout_seconds, - client: builder.build().expect("Failed to build reqwest Client"), + system, + ) + } + + /// Creates a new DeepSeek client with a custom HTTP client. + pub fn with_client( + client: Client, + api_key: impl Into, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + ) -> Self { + Self { + config: Arc::new(DeepSeekConfig { + api_key: api_key.into(), + model: model.unwrap_or("deepseek-chat".to_string()), + max_tokens, + temperature, + system, + timeout_seconds, + }), + client, } } + + pub fn api_key(&self) -> &str { + &self.config.api_key + } + + pub fn model(&self) -> &str { + &self.config.model + } + + pub fn max_tokens(&self) -> Option { + self.config.max_tokens + } + + pub fn temperature(&self) -> Option { + self.config.temperature + } + + pub fn timeout_seconds(&self) -> Option { + self.config.timeout_seconds + } + + pub fn system(&self) -> Option<&str> { + self.config.system.as_deref() + } + + pub fn client(&self) -> &Client { + &self.client + } } #[async_trait] @@ -119,7 +190,7 @@ impl ChatProvider for DeepSeek { /// /// The provider's response text or an error async fn chat(&self, messages: &[ChatMessage]) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing 
DeepSeek API key".to_string())); } @@ -134,7 +205,7 @@ impl ChatProvider for DeepSeek { }) .collect(); - if let Some(system) = &self.system { + if let Some(system) = &self.config.system { deepseek_msgs.insert( 0, DeepSeekChatMessage { @@ -145,9 +216,9 @@ impl ChatProvider for DeepSeek { } let body = DeepSeekChatRequest { - model: &self.model, + model: &self.config.model, messages: deepseek_msgs, - temperature: self.temperature, + temperature: self.config.temperature, stream: false, }; @@ -160,10 +231,10 @@ impl ChatProvider for DeepSeek { let mut request = self .client .post("https://api.deepseek.com/v1/chat/completions") - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .json(&body); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } @@ -230,14 +301,14 @@ impl ModelsProvider for DeepSeek { &self, _request: Option<&ModelListRequest>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing DeepSeek API key".to_string())); } let resp = self .client .get("https://api.deepseek.com/v1/models") - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .send() .await? .error_for_status()?; diff --git a/src/backends/elevenlabs.rs b/src/backends/elevenlabs.rs index 07f3bd2..5968859 100644 --- a/src/backends/elevenlabs.rs +++ b/src/backends/elevenlabs.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::chat::{ChatMessage, ChatProvider, ChatResponse, Tool}; use crate::completion::{CompletionProvider, CompletionRequest, CompletionResponse}; use crate::embedding::EmbeddingProvider; @@ -12,23 +14,34 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; use std::time::Duration; +/// Configuration for the ElevenLabs client. +#[derive(Debug)] +/// Configuration for the ElevenLabs client. +pub struct ElevenLabsConfig { + /// API key for authentication. 
+ pub api_key: String, + /// Model identifier. + pub model_id: String, + /// Base URL for API requests. + pub base_url: String, + /// Request timeout in seconds. + pub timeout_seconds: Option, + /// Voice setting for TTS. + pub voice: Option, +} + /// ElevenLabs speech to text backend implementation /// /// This struct provides functionality for speech-to-text transcription using the ElevenLabs API. /// It implements various LLM provider traits but only supports speech-to-text functionality. +/// +/// The client uses `Arc` internally for configuration, making cloning cheap. +#[derive(Debug, Clone)] pub struct ElevenLabs { - /// API key for ElevenLabs authentication - api_key: String, - /// Model identifier for speech-to-text - model_id: String, - /// Base URL for API requests - base_url: String, - /// Optional timeout duration in seconds - timeout_seconds: Option, - /// HTTP client for making requests - client: Client, - /// Voice ID to use for speech synthesis - voice: Option, + /// Shared configuration wrapped in Arc for cheap cloning. + pub config: Arc, + /// HTTP client for making requests. + pub client: Client, } /// Internal representation of a word from ElevenLabs API response @@ -92,15 +105,60 @@ impl ElevenLabs { timeout_seconds: Option, voice: Option, ) -> Self { - Self { + Self::with_client( + Client::new(), api_key, model_id, base_url, timeout_seconds, - client: Client::new(), voice, + ) + } + + /// Creates a new ElevenLabs instance with a custom HTTP client. 
+ pub fn with_client( + client: Client, + api_key: String, + model_id: String, + base_url: String, + timeout_seconds: Option, + voice: Option, + ) -> Self { + Self { + config: Arc::new(ElevenLabsConfig { + api_key, + model_id, + base_url, + timeout_seconds, + voice, + }), + client, } } + + pub fn api_key(&self) -> &str { + &self.config.api_key + } + + pub fn model_id(&self) -> &str { + &self.config.model_id + } + + pub fn base_url(&self) -> &str { + &self.config.base_url + } + + pub fn timeout_seconds(&self) -> Option { + self.config.timeout_seconds + } + + pub fn voice(&self) -> Option<&str> { + self.config.voice.as_deref() + } + + pub fn client(&self) -> &Client { + &self.client + } } #[async_trait] @@ -116,19 +174,19 @@ impl SpeechToTextProvider for ElevenLabs { /// * `Ok(String)` - Transcribed text /// * `Err(LLMError)` - Error if transcription fails async fn transcribe(&self, audio: Vec) -> Result { - let url = format!("{}/speech-to-text", self.base_url); + let url = format!("{}/speech-to-text", self.config.base_url); let part = reqwest::multipart::Part::bytes(audio).file_name("audio.wav"); let form = reqwest::multipart::Form::new() - .text("model_id", self.model_id.clone()) + .text("model_id", self.config.model_id.clone()) .part("file", part); let mut req = self .client .post(url) - .header("xi-api-key", &self.api_key) + .header("xi-api-key", &self.config.api_key) .multipart(form); - if let Some(t) = self.timeout_seconds { + if let Some(t) = self.config.timeout_seconds { req = req.timeout(Duration::from_secs(t)); } @@ -169,9 +227,9 @@ impl SpeechToTextProvider for ElevenLabs { /// * `Ok(String)` - Transcribed text /// * `Err(LLMError)` - Error if transcription fails async fn transcribe_file(&self, file_path: &str) -> Result { - let url = format!("{}/speech-to-text", self.base_url); + let url = format!("{}/speech-to-text", self.config.base_url); let form = reqwest::multipart::Form::new() - .text("model_id", self.model_id.clone()) + .text("model_id", 
self.config.model_id.clone()) .file("file", file_path) .await .map_err(|e| LLMError::HttpError(e.to_string()))?; @@ -179,10 +237,10 @@ impl SpeechToTextProvider for ElevenLabs { let mut req = self .client .post(url) - .header("xi-api-key", &self.api_key) + .header("xi-api-key", &self.config.api_key) .multipart(form); - if let Some(t) = self.timeout_seconds { + if let Some(t) = self.config.timeout_seconds { req = req.timeout(Duration::from_secs(t)); } @@ -278,25 +336,26 @@ impl TextToSpeechProvider for ElevenLabs { async fn speech(&self, text: &str) -> Result, LLMError> { let url = format!( "{}/text-to-speech/{}?output_format=mp3_44100_128", - self.base_url, - self.voice + self.config.base_url, + self.config + .voice .clone() .unwrap_or("JBFqnCBsd6RMkjVDRZzb".to_string()) ); let body = serde_json::json!({ "text": text, - "model_id": self.model_id + "model_id": self.config.model_id }); let mut req = self .client .post(url) - .header("xi-api-key", &self.api_key) + .header("xi-api-key", &self.config.api_key) .header("Content-Type", "application/json") .json(&body); - if let Some(t) = self.timeout_seconds { + if let Some(t) = self.config.timeout_seconds { req = req.timeout(Duration::from_secs(t)); } diff --git a/src/backends/google.rs b/src/backends/google.rs index 32fbb9d..73fc0df 100644 --- a/src/backends/google.rs +++ b/src/backends/google.rs @@ -40,6 +40,8 @@ //! } //! ``` +use std::sync::Arc; + use crate::{ builder::LLMBackend, chat::{ @@ -62,33 +64,44 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::Value; -/// Client for interacting with Google's Gemini API. -/// -/// This struct holds the configuration and state needed to make requests to the Gemini API. -/// It implements the [`ChatProvider`], [`CompletionProvider`], and [`EmbeddingProvider`] traits. -pub struct Google { - /// API key for authentication with Google's API +/// Configuration for the Google Gemini client. 
+#[derive(Debug)] +pub struct GoogleConfig { + /// API key for authentication with Google. pub api_key: String, - /// Model identifier (e.g. "gemini-1.5-flash") + /// Model identifier (e.g., "gemini-pro"). pub model: String, - /// Maximum number of tokens to generate in responses + /// Maximum tokens to generate in responses. pub max_tokens: Option, - /// Sampling temperature between 0.0 and 1.0 + /// Sampling temperature for response randomness. pub temperature: Option, - /// Optional system prompt to set context + /// System prompt to guide model behavior. pub system: Option, - /// Request timeout in seconds + /// Request timeout in seconds. pub timeout_seconds: Option, - /// Top-p sampling parameter + /// Top-p (nucleus) sampling parameter. pub top_p: Option, - /// Top-k sampling parameter + /// Top-k sampling parameter. pub top_k: Option, - /// JSON schema for structured output + /// JSON schema for structured output. pub json_schema: Option, - /// Available tools for function calling + /// Available tools for the model to use. pub tools: Option>, - /// HTTP client for making API requests - client: Client, +} + +/// Client for interacting with Google's Gemini API. +/// +/// This struct holds the configuration and state needed to make requests to the Gemini API. +/// It implements the [`ChatProvider`], [`CompletionProvider`], and [`EmbeddingProvider`] traits. +/// +/// The client uses `Arc` internally for configuration, making cloning cheap +/// (only an atomic reference count increment). +#[derive(Debug, Clone)] +pub struct Google { + /// Shared configuration wrapped in Arc for cheap cloning. + pub config: Arc, + /// HTTP client for making requests. 
+ pub client: Client, } /// Request body for chat completions @@ -474,7 +487,6 @@ impl Google { /// * `temperature` - Sampling temperature between 0.0 and 1.0 /// * `timeout_seconds` - Request timeout in seconds /// * `system` - System prompt to set context - /// * `stream` - Whether to stream responses /// * `top_p` - Top-p sampling parameter /// * `top_k` - Top-k sampling parameter /// * `json_schema` - JSON schema for structured output @@ -500,20 +512,96 @@ impl Google { if let Some(sec) = timeout_seconds { builder = builder.timeout(std::time::Duration::from_secs(sec)); } - Self { - api_key: api_key.into(), - model: model.unwrap_or_else(|| "gemini-1.5-flash".to_string()), + Self::with_client( + builder.build().expect("Failed to build reqwest Client"), + api_key, + model, max_tokens, temperature, - system, timeout_seconds, + system, top_p, top_k, json_schema, tools, - client: builder.build().expect("Failed to build reqwest Client"), + ) + } + + /// Creates a new Google Gemini client with a custom HTTP client. 
+ #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + api_key: impl Into, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + json_schema: Option, + tools: Option>, + ) -> Self { + Self { + config: Arc::new(GoogleConfig { + api_key: api_key.into(), + model: model.unwrap_or_else(|| "gemini-1.5-flash".to_string()), + max_tokens, + temperature, + system, + timeout_seconds, + top_p, + top_k, + json_schema, + tools, + }), + client, } } + + pub fn api_key(&self) -> &str { + &self.config.api_key + } + + pub fn model(&self) -> &str { + &self.config.model + } + + pub fn max_tokens(&self) -> Option { + self.config.max_tokens + } + + pub fn temperature(&self) -> Option { + self.config.temperature + } + + pub fn timeout_seconds(&self) -> Option { + self.config.timeout_seconds + } + + pub fn system(&self) -> Option<&str> { + self.config.system.as_deref() + } + + pub fn top_p(&self) -> Option { + self.config.top_p + } + + pub fn top_k(&self) -> Option { + self.config.top_k + } + + pub fn json_schema(&self) -> Option<&StructuredOutputFormat> { + self.config.json_schema.as_ref() + } + + pub fn tools(&self) -> Option<&[Tool]> { + self.config.tools.as_deref() + } + + pub fn client(&self) -> &Client { + &self.client + } } #[async_trait] @@ -528,14 +616,14 @@ impl ChatProvider for Google { /// /// The model's response text or an error async fn chat(&self, messages: &[ChatMessage]) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Google API key".to_string())); } let mut chat_contents = Vec::with_capacity(messages.len()); // Add system message if present - if let Some(system) = &self.system { + if let Some(system) = &self.config.system { chat_contents.push(GoogleChatContent { role: "user", parts: vec![GoogleContentPart::Text(system)], @@ -601,36 +689,36 @@ impl ChatProvider for Google { } 
// Remove generation_config if empty to avoid validation errors - let generation_config = if self.max_tokens.is_none() - && self.temperature.is_none() - && self.top_p.is_none() - && self.top_k.is_none() - && self.json_schema.is_none() + let generation_config = if self.config.max_tokens.is_none() + && self.config.temperature.is_none() + && self.config.top_p.is_none() + && self.config.top_k.is_none() + && self.config.json_schema.is_none() { None } else { // If json_schema and json_schema.schema are not None, use json_schema.schema as the response schema and set response_mime_type to JSON // Google's API doesn't need the schema to have a "name" field, so we can just use the schema directly. - let (response_mime_type, response_schema) = if let Some(json_schema) = &self.json_schema - { - if let Some(schema) = &json_schema.schema { - // If the schema has an "additionalProperties" field (as required by OpenAI), remove it as Google's API doesn't support it - let mut schema = schema.clone(); - if let Some(obj) = schema.as_object_mut() { - obj.remove("additionalProperties"); + let (response_mime_type, response_schema) = + if let Some(json_schema) = &self.config.json_schema { + if let Some(schema) = &json_schema.schema { + // If the schema has an "additionalProperties" field (as required by OpenAI), remove it as Google's API doesn't support it + let mut schema = schema.clone(); + if let Some(obj) = schema.as_object_mut() { + obj.remove("additionalProperties"); + } + (Some(GoogleResponseMimeType::Json), Some(schema)) + } else { + (None, None) } - (Some(GoogleResponseMimeType::Json), Some(schema)) } else { (None, None) - } - } else { - (None, None) - }; + }; Some(GoogleGenerationConfig { - max_output_tokens: self.max_tokens, - temperature: self.temperature, - top_p: self.top_p, - top_k: self.top_k, + max_output_tokens: self.config.max_tokens, + temperature: self.config.temperature, + top_p: self.config.top_p, + top_k: self.config.top_k, response_mime_type, response_schema, }) 
@@ -649,12 +737,12 @@ impl ChatProvider for Google { let url = format!( "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={key}", - model = self.model, - key = self.api_key + model = self.config.model, + key = self.config.api_key ); let mut request = self.client.post(&url).json(&req_body); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } @@ -696,14 +784,14 @@ impl ChatProvider for Google { messages: &[ChatMessage], tools: Option<&[Tool]>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Google API key".to_string())); } let mut chat_contents = Vec::with_capacity(messages.len()); // Add system message if present - if let Some(system) = &self.system { + if let Some(system) = &self.config.system { chat_contents.push(GoogleChatContent { role: "user", parts: vec![GoogleContentPart::Text(system)], @@ -779,29 +867,29 @@ impl ChatProvider for Google { let generation_config = { // If json_schema and json_schema.schema are not None, use json_schema.schema as the response schema and set response_mime_type to JSON // Google's API doesn't need the schema to have a "name" field, so we can just use the schema directly. 
- let (response_mime_type, response_schema) = if let Some(json_schema) = &self.json_schema - { - if let Some(schema) = &json_schema.schema { - // If the schema has an "additionalProperties" field (as required by OpenAI), remove it as Google's API doesn't support it - let mut schema = schema.clone(); - - if let Some(obj) = schema.as_object_mut() { - obj.remove("additionalProperties"); - } + let (response_mime_type, response_schema) = + if let Some(json_schema) = &self.config.json_schema { + if let Some(schema) = &json_schema.schema { + // If the schema has an "additionalProperties" field (as required by OpenAI), remove it as Google's API doesn't support it + let mut schema = schema.clone(); + + if let Some(obj) = schema.as_object_mut() { + obj.remove("additionalProperties"); + } - (Some(GoogleResponseMimeType::Json), Some(schema)) + (Some(GoogleResponseMimeType::Json), Some(schema)) + } else { + (None, None) + } } else { (None, None) - } - } else { - (None, None) - }; + }; Some(GoogleGenerationConfig { - max_output_tokens: self.max_tokens, - temperature: self.temperature, - top_p: self.top_p, - top_k: self.top_k, + max_output_tokens: self.config.max_tokens, + temperature: self.config.temperature, + top_p: self.config.top_p, + top_k: self.config.top_k, response_mime_type, response_schema, }) @@ -821,14 +909,14 @@ impl ChatProvider for Google { let url = format!( "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={key}", - model = self.model, - key = self.api_key + model = self.config.model, + key = self.config.api_key ); let mut request = self.client.post(&url).json(&req_body); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } @@ -906,11 +994,11 @@ impl ChatProvider for Google { std::pin::Pin> + Send>>, LLMError, > { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return 
Err(LLMError::AuthError("Missing Google API key".to_string())); } let mut chat_contents = Vec::with_capacity(messages.len()); - if let Some(system) = &self.system { + if let Some(system) = &self.config.system { chat_contents.push(GoogleChatContent { role: "user", parts: vec![GoogleContentPart::Text(system)], @@ -941,18 +1029,18 @@ impl ChatProvider for Google { }, }); } - let generation_config = if self.max_tokens.is_none() - && self.temperature.is_none() - && self.top_p.is_none() - && self.top_k.is_none() + let generation_config = if self.config.max_tokens.is_none() + && self.config.temperature.is_none() + && self.config.top_p.is_none() + && self.config.top_k.is_none() { None } else { Some(GoogleGenerationConfig { - max_output_tokens: self.max_tokens, - temperature: self.temperature, - top_p: self.top_p, - top_k: self.top_k, + max_output_tokens: self.config.max_tokens, + temperature: self.config.temperature, + top_p: self.config.top_p, + top_k: self.config.top_k, response_mime_type: None, response_schema: None, }) @@ -965,12 +1053,12 @@ impl ChatProvider for Google { }; let url = format!( "https://generativelanguage.googleapis.com/v1beta/models/{model}:streamGenerateContent?alt=sse&key={key}", - model = self.model, - key = self.api_key + model = self.config.model, + key = self.config.api_key ); let mut request = self.client.post(&url).json(&req_body); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } let response = request.send().await?; @@ -1012,7 +1100,7 @@ impl CompletionProvider for Google { #[async_trait] impl EmbeddingProvider for Google { async fn embed(&self, texts: Vec) -> Result>, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Google API key".to_string())); } @@ -1029,7 +1117,7 @@ impl EmbeddingProvider for Google { let url = format!( 
"https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent?key={}", - self.api_key + self.config.api_key ); let resp = self @@ -1058,7 +1146,7 @@ impl SpeechToTextProvider for Google { impl LLMProvider for Google { fn tools(&self) -> Option<&[Tool]> { - self.tools.as_deref() + self.config.tools.as_deref() } } @@ -1222,13 +1310,13 @@ impl ModelsProvider for Google { &self, _request: Option<&ModelListRequest>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Google API key".to_string())); } let url = format!( "https://generativelanguage.googleapis.com/v1beta/models?key={}", - self.api_key + self.config.api_key ); let resp = self.client.get(&url).send().await?.error_for_status()?; diff --git a/src/backends/groq.rs b/src/backends/groq.rs index 6a7a97f..8fd04fc 100644 --- a/src/backends/groq.rs +++ b/src/backends/groq.rs @@ -93,7 +93,7 @@ impl Groq { impl LLMProvider for Groq { fn tools(&self) -> Option<&[Tool]> { - self.tools.as_deref() + self.config.tools.as_deref() } } @@ -133,7 +133,7 @@ impl ModelsProvider for Groq { &self, _request: Option<&ModelListRequest>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Groq API key".to_string())); } @@ -142,7 +142,7 @@ impl ModelsProvider for Groq { let resp = self .client .get(&url) - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .send() .await? 
.error_for_status()?; diff --git a/src/backends/huggingface.rs b/src/backends/huggingface.rs index 233ab85..0faa049 100644 --- a/src/backends/huggingface.rs +++ b/src/backends/huggingface.rs @@ -79,7 +79,7 @@ impl HuggingFace { impl LLMProvider for HuggingFace { fn tools(&self) -> Option<&[Tool]> { - self.tools.as_deref() + self.config.tools.as_deref() } } @@ -119,7 +119,7 @@ impl ModelsProvider for HuggingFace { &self, _request: Option<&ModelListRequest>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError( "Missing HuggingFace API key".to_string(), )); @@ -130,7 +130,7 @@ impl ModelsProvider for HuggingFace { let resp = self .client .get(&url) - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .send() .await? .error_for_status()?; diff --git a/src/backends/mistral.rs b/src/backends/mistral.rs index 761d915..862d9fb 100644 --- a/src/backends/mistral.rs +++ b/src/backends/mistral.rs @@ -104,7 +104,7 @@ struct MistralEmbeddingResponse { impl LLMProvider for Mistral { fn tools(&self) -> Option<&[Tool]> { - self.tools.as_deref() + self.config.tools.as_deref() } } @@ -136,18 +136,23 @@ impl SpeechToTextProvider for Mistral { #[async_trait] impl EmbeddingProvider for Mistral { async fn embed(&self, input: Vec) -> Result>, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Mistral API key".into())); } let body = MistralEmbeddingRequest { - model: self.model.clone(), + model: self.config.model.to_owned(), input, - encoding_format: self.embedding_encoding_format.clone(), - dimensions: self.embedding_dimensions, + encoding_format: self + .config + .embedding_encoding_format + .as_deref() + .map(|s| s.to_owned()), + dimensions: self.config.embedding_dimensions, }; let url = self + .config .base_url .join("embeddings") .map_err(|e| LLMError::HttpError(e.to_string()))?; @@ -155,7 +160,7 @@ impl EmbeddingProvider for 
Mistral { let resp = self .client .post(url) - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .json(&body) .send() .await? @@ -173,14 +178,14 @@ impl ModelsProvider for Mistral { &self, _request: Option<&ModelListRequest>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing Mistral API key".to_string())); } let url = format!("{}models", MistralConfig::DEFAULT_BASE_URL); let resp = self .client .get(&url) - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .send() .await? .error_for_status()?; diff --git a/src/backends/ollama.rs b/src/backends/ollama.rs index cd71fec..fb38723 100644 --- a/src/backends/ollama.rs +++ b/src/backends/ollama.rs @@ -3,6 +3,7 @@ //! This module provides integration with Ollama's local LLM server through its API. use std::pin::Pin; +use std::sync::Arc; use crate::{ builder::LLMBackend, @@ -26,24 +27,44 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::Value; -/// Client for interacting with Ollama's API. -/// -/// Provides methods for chat and completion requests using Ollama's models. -pub struct Ollama { +/// Configuration for the Ollama client. +#[derive(Debug)] +pub struct OllamaConfig { + /// Base URL for the Ollama API. pub base_url: String, + /// Optional API key for authentication. pub api_key: Option, + /// Model identifier. pub model: String, + /// Maximum tokens to generate in responses. pub max_tokens: Option, + /// Sampling temperature for response randomness. pub temperature: Option, + /// System prompt to guide model behavior. pub system: Option, + /// Request timeout in seconds. pub timeout_seconds: Option, + /// Top-p (nucleus) sampling parameter. pub top_p: Option, + /// Top-k sampling parameter. pub top_k: Option, - /// JSON schema for structured output + /// JSON schema for structured output. 
pub json_schema: Option, - /// Available tools for function calling + /// Available tools for the model to use. pub tools: Option>, - client: Client, +} + +/// Client for interacting with Ollama's API. +/// +/// Provides methods for chat and completion requests using Ollama's models. +/// +/// The client uses `Arc` internally for configuration, making cloning cheap. +#[derive(Debug, Clone)] +pub struct Ollama { + /// Shared configuration wrapped in Arc for cheap cloning. + pub config: Arc, + /// HTTP client for making requests. + pub client: Client, } /// Request payload for Ollama's chat API endpoint. @@ -299,7 +320,6 @@ impl Ollama { /// * `temperature` - Sampling temperature /// * `timeout_seconds` - Request timeout in seconds /// * `system` - System prompt - /// * `stream` - Whether to stream responses /// * `json_schema` - JSON schema for structured output /// * `tools` - Function tools that the model can use #[allow(clippy::too_many_arguments)] @@ -321,22 +341,104 @@ impl Ollama { if let Some(sec) = timeout_seconds { builder = builder.timeout(std::time::Duration::from_secs(sec)); } - Self { - base_url: base_url.into(), + Self::with_client( + builder.build().expect("Failed to build reqwest Client"), + base_url, api_key, - model: model.unwrap_or("llama3.1".to_string()), - temperature, + model, max_tokens, + temperature, timeout_seconds, system, top_p, top_k, json_schema, tools, - client: builder.build().expect("Failed to build reqwest Client"), + ) + } + + /// Creates a new Ollama client with a custom HTTP client. 
+ #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + base_url: impl Into, + api_key: Option, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + json_schema: Option, + tools: Option>, + ) -> Self { + Self { + config: Arc::new(OllamaConfig { + base_url: base_url.into(), + api_key, + model: model.unwrap_or("llama3.1".to_string()), + temperature, + max_tokens, + timeout_seconds, + system, + top_p, + top_k, + json_schema, + tools, + }), + client, } } + pub fn base_url(&self) -> &str { + &self.config.base_url + } + + pub fn api_key(&self) -> Option<&str> { + self.config.api_key.as_deref() + } + + pub fn model(&self) -> &str { + &self.config.model + } + + pub fn max_tokens(&self) -> Option { + self.config.max_tokens + } + + pub fn temperature(&self) -> Option { + self.config.temperature + } + + pub fn timeout_seconds(&self) -> Option { + self.config.timeout_seconds + } + + pub fn system(&self) -> Option<&str> { + self.config.system.as_deref() + } + + pub fn top_p(&self) -> Option { + self.config.top_p + } + + pub fn top_k(&self) -> Option { + self.config.top_k + } + + pub fn json_schema(&self) -> Option<&StructuredOutputFormat> { + self.config.json_schema.as_ref() + } + + pub fn tools(&self) -> Option<&[Tool]> { + self.config.tools.as_deref() + } + + pub fn client(&self) -> &Client { + &self.client + } + fn make_chat_request<'a>( &'a self, messages: &'a [ChatMessage], @@ -346,7 +448,7 @@ impl Ollama { let mut chat_messages: Vec = messages.iter().map(OllamaChatMessage::from).collect(); - if let Some(system) = &self.system { + if let Some(system) = &self.config.system { chat_messages.insert( 0, OllamaChatMessage { @@ -361,7 +463,7 @@ impl Ollama { let ollama_tools = tools.map(|t| t.iter().map(OllamaTool::from).collect()); // Ollama doesn't require the "name" field in the schema, so we just use the schema itself - let format = if let Some(schema) = 
&self.json_schema { + let format = if let Some(schema) = &self.config.json_schema { schema.schema.as_ref().map(|schema| OllamaResponseFormat { format: OllamaResponseType::StructuredOutput(schema.clone()), }) @@ -370,12 +472,12 @@ impl Ollama { }; OllamaChatRequest { - model: self.model.clone(), + model: self.config.model.clone(), messages: chat_messages, stream, options: Some(OllamaOptions { - top_p: self.top_p, - top_k: self.top_k, + top_p: self.config.top_p, + top_k: self.config.top_k, }), format, tools: ollama_tools, @@ -390,7 +492,7 @@ impl ChatProvider for Ollama { messages: &[ChatMessage], tools: Option<&[Tool]>, ) -> Result, LLMError> { - if self.base_url.is_empty() { + if self.config.base_url.is_empty() { return Err(LLMError::InvalidRequest("Missing base_url".to_string())); } @@ -402,11 +504,11 @@ impl ChatProvider for Ollama { } } - let url = format!("{}/api/chat", self.base_url); + let url = format!("{}/api/chat", self.config.base_url); let mut request = self.client.post(&url).json(&req_body); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } @@ -426,10 +528,10 @@ impl ChatProvider for Ollama { ) -> Result> + Send>>, LLMError> { let req_body = self.make_chat_request(messages, None, true); - let url = format!("{}/api/chat", self.base_url); + let url = format!("{}/api/chat", self.config.base_url); let mut request = self.client.post(&url).json(&req_body); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } @@ -454,13 +556,13 @@ impl CompletionProvider for Ollama { /// /// The completion response containing the generated text or an error async fn complete(&self, req: &CompletionRequest) -> Result { - if self.base_url.is_empty() { + if self.config.base_url.is_empty() { return Err(LLMError::InvalidRequest("Missing 
base_url".to_string())); } - let url = format!("{}/api/generate", self.base_url); + let url = format!("{}/api/generate", self.config.base_url); let req_body = OllamaGenerateRequest { - model: self.model.clone(), + model: self.config.model.clone(), prompt: &req.prompt, raw: true, stream: false, @@ -488,13 +590,13 @@ impl CompletionProvider for Ollama { #[async_trait] impl EmbeddingProvider for Ollama { async fn embed(&self, text: Vec) -> Result>, LLMError> { - if self.base_url.is_empty() { + if self.config.base_url.is_empty() { return Err(LLMError::InvalidRequest("Missing base_url".to_string())); } - let url = format!("{}/api/embed", self.base_url); + let url = format!("{}/api/embed", self.config.base_url); let body = OllamaEmbeddingRequest { - model: self.model.clone(), + model: self.config.model.clone(), input: text, }; @@ -582,15 +684,15 @@ impl ModelsProvider for Ollama { &self, _request: Option<&ModelListRequest>, ) -> Result, LLMError> { - if self.base_url.is_empty() { + if self.config.base_url.is_empty() { return Err(LLMError::InvalidRequest("Missing base_url".to_string())); } - let url = format!("{}/api/tags", self.base_url); + let url = format!("{}/api/tags", self.config.base_url); let mut request = self.client.get(&url); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } @@ -602,7 +704,7 @@ impl ModelsProvider for Ollama { impl crate::LLMProvider for Ollama { fn tools(&self) -> Option<&[Tool]> { - self.tools.as_deref() + self.config.tools.as_deref() } } diff --git a/src/backends/openai.rs b/src/backends/openai.rs index 026a9af..8205dc6 100644 --- a/src/backends/openai.rs +++ b/src/backends/openai.rs @@ -286,12 +286,17 @@ impl ChatProvider for OpenAI { ) -> Result, LLMError> { // Use the common prepare_messages method from the OpenAI-compatible provider let openai_msgs = self.provider.prepare_messages(messages); - let response_format: 
Option = - self.provider.json_schema.clone().map(|s| s.into()); + let response_format: Option = self + .provider + .config + .json_schema + .as_ref() + .cloned() + .map(|s| s.into()); // Convert regular tools to OpenAI format let tool_calls = tools .map(|t| t.to_vec()) - .or_else(|| self.provider.tools.clone()); + .or_else(|| self.provider.config.tools.as_deref().map(|t| t.to_vec())); let mut openai_tools: Vec = Vec::new(); // Add regular function tools if let Some(tools) = &tool_calls { @@ -308,29 +313,35 @@ impl ChatProvider for OpenAI { Some(openai_tools) }; let request_tool_choice = if final_tools.is_some() { - self.provider.tool_choice.clone() + self.provider.config.tool_choice.as_ref().cloned() } else { None }; let body = OpenAIAPIChatRequest { - model: self.provider.model.as_str(), + model: &self.provider.config.model, messages: openai_msgs, input: None, - max_completion_tokens: self.provider.max_tokens, + max_completion_tokens: self.provider.config.max_tokens, max_output_tokens: None, - temperature: self.provider.temperature, + temperature: self.provider.config.temperature, stream: false, - top_p: self.provider.top_p, - top_k: self.provider.top_k, + top_p: self.provider.config.top_p, + top_k: self.provider.config.top_k, tools: final_tools, tool_choice: request_tool_choice, - reasoning_effort: self.provider.reasoning_effort.clone(), + reasoning_effort: self + .provider + .config + .reasoning_effort + .as_deref() + .map(|s| s.to_owned()), response_format, stream_options: None, - extra_body: self.provider.extra_body.clone(), + extra_body: self.provider.config.extra_body.clone(), }; let url = self .provider + .config .base_url .join("chat/completions") .map_err(|e| LLMError::HttpError(e.to_string()))?; @@ -338,14 +349,14 @@ impl ChatProvider for OpenAI { .provider .client .post(url) - .bearer_auth(&self.provider.api_key) + .bearer_auth(&self.provider.config.api_key) .json(&body); if log::log_enabled!(log::Level::Trace) { if let Ok(json) = 
serde_json::to_string(&body) { log::trace!("OpenAI request payload: {}", json); } } - if let Some(timeout) = self.provider.timeout_seconds { + if let Some(timeout) = self.provider.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } let response = request.send().await?; @@ -436,36 +447,43 @@ impl ChatProvider for OpenAI { > { let openai_msgs = self.provider.prepare_messages(messages); // Convert regular tools to OpenAI format for streaming - let openai_tools: Option> = self.provider.tools.as_ref().map(|tools| { - tools - .iter() - .map(|tool| OpenAITool::Function { - tool_type: tool.tool_type.clone(), - function: tool.function.clone(), - }) - .collect() - }); + let openai_tools: Option> = + self.provider.config.tools.as_deref().map(|tools| { + tools + .iter() + .map(|tool| OpenAITool::Function { + tool_type: tool.tool_type.clone(), + function: tool.function.clone(), + }) + .collect() + }); let body = OpenAIAPIChatRequest { - model: &self.provider.model, + model: &self.provider.config.model, messages: openai_msgs, input: None, - max_completion_tokens: self.provider.max_tokens, + max_completion_tokens: self.provider.config.max_tokens, max_output_tokens: None, - temperature: self.provider.temperature, + temperature: self.provider.config.temperature, stream: true, - top_p: self.provider.top_p, - top_k: self.provider.top_k, + top_p: self.provider.config.top_p, + top_k: self.provider.config.top_k, tools: openai_tools, - tool_choice: self.provider.tool_choice.clone(), - reasoning_effort: self.provider.reasoning_effort.clone(), + tool_choice: self.provider.config.tool_choice.as_ref().cloned(), + reasoning_effort: self + .provider + .config + .reasoning_effort + .as_deref() + .map(|s| s.to_owned()), response_format: None, stream_options: Some(OpenAIStreamOptions { include_usage: true, }), - extra_body: self.provider.extra_body.clone(), + extra_body: self.provider.config.extra_body.clone(), }; let url = self .provider + .config 
.base_url .join("chat/completions") .map_err(|e| LLMError::HttpError(e.to_string()))?; @@ -473,9 +491,9 @@ impl ChatProvider for OpenAI { .provider .client .post(url) - .bearer_auth(&self.provider.api_key) + .bearer_auth(&self.provider.config.api_key) .json(&body); - if let Some(timeout) = self.provider.timeout_seconds { + if let Some(timeout) = self.provider.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } let response = request.send().await?; @@ -489,7 +507,7 @@ impl ChatProvider for OpenAI { } Ok(create_sse_stream( response, - self.provider.normalize_response, + self.provider.config.normalize_response, )) } @@ -523,24 +541,27 @@ impl SpeechToTextProvider for OpenAI { async fn transcribe_file(&self, file_path: &str) -> Result { let url = self - .base_url() + .provider + .config + .base_url .join("audio/transcriptions") .map_err(|e| LLMError::HttpError(e.to_string()))?; let form = reqwest::multipart::Form::new() - .text("model", self.model().to_string()) + .text("model", self.provider.config.model.to_string()) .text("response_format", "text") .file("file", file_path) .await .map_err(|e| LLMError::HttpError(e.to_string()))?; let mut req = self - .client() + .provider + .client .post(url) - .bearer_auth(self.api_key()) + .bearer_auth(&self.provider.config.api_key) .multipart(form); - if let Some(t) = self.timeout_seconds() { + if let Some(t) = self.provider.config.timeout_seconds { req = req.timeout(Duration::from_secs(t)); } @@ -564,21 +585,29 @@ impl TextToSpeechProvider for OpenAI { impl EmbeddingProvider for OpenAI { async fn embed(&self, input: Vec) -> Result>, LLMError> { let body = OpenAIEmbeddingRequest { - model: self.model().to_string(), + model: self.provider.config.model.to_string(), input, - encoding_format: self.provider.embedding_encoding_format.clone(), - dimensions: self.provider.embedding_dimensions, + encoding_format: self + .provider + .config + .embedding_encoding_format + .as_deref() + .map(|s| 
s.to_owned()), + dimensions: self.provider.config.embedding_dimensions, }; let url = self - .base_url() + .provider + .config + .base_url .join("embeddings") .map_err(|e| LLMError::HttpError(e.to_string()))?; let resp = self - .client() + .provider + .client .post(url) - .bearer_auth(self.api_key()) + .bearer_auth(&self.provider.config.api_key) .json(&body) .send() .await? @@ -597,14 +626,17 @@ impl ModelsProvider for OpenAI { _request: Option<&ModelListRequest>, ) -> Result, LLMError> { let url = self - .base_url() + .provider + .config + .base_url .join("models") .map_err(|e| LLMError::HttpError(e.to_string()))?; let resp = self - .client() + .provider + .client .get(url) - .bearer_auth(self.api_key()) + .bearer_auth(&self.provider.config.api_key) .send() .await? .error_for_status()?; @@ -619,22 +651,21 @@ impl ModelsProvider for OpenAI { impl LLMProvider for OpenAI {} -// Helper methods to access provider fields impl OpenAI { pub fn api_key(&self) -> &str { - &self.provider.api_key + &self.provider.config.api_key } pub fn model(&self) -> &str { - &self.provider.model + &self.provider.config.model } pub fn base_url(&self) -> &reqwest::Url { - &self.provider.base_url + &self.provider.config.base_url } pub fn timeout_seconds(&self) -> Option { - self.provider.timeout_seconds + self.provider.config.timeout_seconds } pub fn client(&self) -> &reqwest::Client { @@ -642,7 +673,7 @@ impl OpenAI { } pub fn tools(&self) -> Option<&[Tool]> { - self.provider.tools.as_deref() + self.provider.config.tools.as_deref() } /// Chat with OpenAI-hosted tools using the `/responses` endpoint @@ -661,25 +692,31 @@ impl OpenAI { hosted_tools: Vec, ) -> Result, LLMError> { let body = OpenAIAPIChatRequest { - model: self.provider.model.as_str(), + model: &self.provider.config.model, messages: Vec::new(), // Empty for hosted tools input: Some(input), max_completion_tokens: None, - max_output_tokens: self.provider.max_tokens, - temperature: self.provider.temperature, + max_output_tokens: 
self.provider.config.max_tokens, + temperature: self.provider.config.temperature, stream: false, - top_p: self.provider.top_p, - top_k: self.provider.top_k, + top_p: self.provider.config.top_p, + top_k: self.provider.config.top_k, tools: Some(hosted_tools), - tool_choice: self.provider.tool_choice.clone(), - reasoning_effort: self.provider.reasoning_effort.clone(), + tool_choice: self.provider.config.tool_choice.as_ref().cloned(), + reasoning_effort: self + .provider + .config + .reasoning_effort + .as_deref() + .map(|s| s.to_owned()), response_format: None, // Hosted tools don't use structured output stream_options: None, - extra_body: self.provider.extra_body.clone(), + extra_body: self.provider.config.extra_body.clone(), }; let url = self .provider + .config .base_url .join("responses") // Use responses endpoint for hosted tools .map_err(|e| LLMError::HttpError(e.to_string()))?; @@ -688,7 +725,7 @@ impl OpenAI { .provider .client .post(url) - .bearer_auth(&self.provider.api_key) + .bearer_auth(&self.provider.config.api_key) .json(&body); if log::log_enabled!(log::Level::Trace) { @@ -697,7 +734,7 @@ impl OpenAI { } } - if let Some(timeout) = self.provider.timeout_seconds { + if let Some(timeout) = self.provider.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } diff --git a/src/backends/openrouter.rs b/src/backends/openrouter.rs index c0ca170..633400b 100644 --- a/src/backends/openrouter.rs +++ b/src/backends/openrouter.rs @@ -79,7 +79,7 @@ impl OpenRouter { impl LLMProvider for OpenRouter { fn tools(&self) -> Option<&[Tool]> { - self.tools.as_deref() + self.config.tools.as_deref() } } @@ -119,7 +119,7 @@ impl ModelsProvider for OpenRouter { &self, _request: Option<&ModelListRequest>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError( "Missing OpenRouter API key".to_string(), )); @@ -130,7 +130,7 @@ impl ModelsProvider for OpenRouter { let resp = 
self .client .get(&url) - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .send() .await? .error_for_status()?; diff --git a/src/backends/phind.rs b/src/backends/phind.rs index 4698bc8..60ce9bd 100644 --- a/src/backends/phind.rs +++ b/src/backends/phind.rs @@ -1,5 +1,7 @@ /// Implementation of the Phind LLM provider. /// This module provides integration with Phind's language model API. +use std::sync::Arc; + #[cfg(feature = "phind")] use crate::{ chat::{ChatMessage, ChatProvider, ChatRole}, @@ -21,26 +23,36 @@ use reqwest::StatusCode; use reqwest::{Client, Response}; use serde_json::{json, Value}; -/// Represents a Phind LLM client with configuration options. -pub struct Phind { - /// The model identifier to use (e.g. "Phind-70B") +/// Configuration for the Phind client. +#[derive(Debug)] +pub struct PhindConfig { + /// Model identifier. pub model: String, - /// Maximum number of tokens to generate + /// Maximum tokens to generate in responses. pub max_tokens: Option, - /// Temperature for controlling randomness (0.0-1.0) + /// Sampling temperature for response randomness. pub temperature: Option, - /// System prompt to prepend to conversations + /// System prompt to guide model behavior. pub system: Option, - /// Request timeout in seconds + /// Request timeout in seconds. pub timeout_seconds: Option, - /// Top-p sampling parameter + /// Top-p (nucleus) sampling parameter. pub top_p: Option, - /// Top-k sampling parameter + /// Top-k sampling parameter. pub top_k: Option, - /// Base URL for the Phind API + /// Base URL for API requests. pub api_base_url: String, - /// HTTP client for making requests - client: Client, +} + +/// Represents a Phind LLM client with configuration options. +/// +/// The client uses `Arc` internally for configuration, making cloning cheap. +#[derive(Debug, Clone)] +pub struct Phind { + /// Shared configuration wrapped in Arc for cheap cloning. + pub config: Arc, + /// HTTP client for making requests. 
+ pub client: Client, } #[derive(Debug)] @@ -80,19 +92,81 @@ impl Phind { if let Some(sec) = timeout_seconds { builder = builder.timeout(std::time::Duration::from_secs(sec)); } - Self { - model: model.unwrap_or_else(|| "Phind-70B".to_string()), + Self::with_client( + builder.build().expect("Failed to build reqwest Client"), + model, max_tokens, temperature, - system, timeout_seconds, + system, top_p, top_k, - api_base_url: "https://https.extension.phind.com/agent/".to_string(), - client: builder.build().expect("Failed to build reqwest Client"), + ) + } + + /// Creates a new Phind client with a custom HTTP client. + #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + ) -> Self { + Self { + config: Arc::new(PhindConfig { + model: model.unwrap_or_else(|| "Phind-70B".to_string()), + max_tokens, + temperature, + system, + timeout_seconds, + top_p, + top_k, + api_base_url: "https://https.extension.phind.com/agent/".to_string(), + }), + client, } } + pub fn model(&self) -> &str { + &self.config.model + } + + pub fn max_tokens(&self) -> Option { + self.config.max_tokens + } + + pub fn temperature(&self) -> Option { + self.config.temperature + } + + pub fn timeout_seconds(&self) -> Option { + self.config.timeout_seconds + } + + pub fn system(&self) -> Option<&str> { + self.config.system.as_deref() + } + + pub fn top_p(&self) -> Option { + self.config.top_p + } + + pub fn top_k(&self) -> Option { + self.config.top_k + } + + pub fn api_base_url(&self) -> &str { + &self.config.api_base_url + } + + pub fn client(&self) -> &Client { + &self.client + } + /// Creates the required headers for API requests. 
fn create_headers() -> Result { let mut headers = HeaderMap::new(); @@ -189,7 +263,7 @@ impl ChatProvider for Phind { })); } - if let Some(system_prompt) = &self.system { + if let Some(system_prompt) = &self.config.system { message_history.insert( 0, json!({ @@ -204,7 +278,7 @@ "allow_magic_buttons": true, "is_vscode_extension": true, "message_history": message_history, - "requested_model": self.model, + "requested_model": self.config.model, "user_input": messages .iter() .rev() @@ -220,11 +294,11 @@ let headers = Self::create_headers()?; let mut request = self .client - .post(&self.api_base_url) + .post(&self.config.api_base_url) .headers(headers) .json(&payload); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } diff --git a/src/backends/xai.rs b/src/backends/xai.rs index 1028829..d0c5118 100644 --- a/src/backends/xai.rs +++ b/src/backends/xai.rs @@ -3,6 +3,8 @@ //! This module provides integration with X.AI's models through their API. //! It implements chat and completion capabilities using the X.AI API endpoints. +use std::sync::Arc; + use crate::ToolCall; #[cfg(feature = "xai")] use crate::{ @@ -23,47 +25,58 @@ use futures::stream::Stream; use reqwest::Client; use serde::{Deserialize, Serialize}; -/// Client for interacting with X.AI's API. -/// -/// This struct provides methods for making chat and completion requests to X.AI's language models. -/// It handles authentication, request configuration, and response parsing. -pub struct XAI { - /// API key for authentication with X.AI services +/// Configuration for the X.AI client. +/// Wrapped in an `Arc` by the `XAI` client to enable cheap cloning. +#[derive(Debug)] +pub struct XAIConfig { + /// API key for authentication with X.AI. pub api_key: String, - /// Model identifier to use for requests (e.g. "grok-2-latest") + /// Model identifier.
pub model: String, - /// Maximum number of tokens to generate in responses + /// Maximum tokens to generate in responses. pub max_tokens: Option, - /// Temperature parameter for controlling response randomness (0.0 to 1.0) + /// Sampling temperature for response randomness. pub temperature: Option, - /// Optional system prompt to provide context + /// System prompt to guide model behavior. pub system: Option, - /// Request timeout duration in seconds + /// Request timeout in seconds. pub timeout_seconds: Option, - /// Top-p sampling parameter for controlling response diversity + /// Top-p (nucleus) sampling parameter. pub top_p: Option, - /// Top-k sampling parameter for controlling response diversity + /// Top-k sampling parameter. pub top_k: Option, - /// Embedding encoding format + /// Encoding format for embeddings. pub embedding_encoding_format: Option, - /// Embedding dimensions + /// Dimensions for embeddings. pub embedding_dimensions: Option, - /// JSON schema for structured output + /// JSON schema for structured output. pub json_schema: Option, - /// XAI search parameters + /// Search mode for web search functionality. pub xai_search_mode: Option, - /// XAI search sources + /// Source type for search. pub xai_search_source_type: Option, - /// XAI search excluded websites + /// Websites to exclude from search. pub xai_search_excluded_websites: Option>, - /// XAI search max results + /// Maximum number of search results. pub xai_search_max_results: Option, - /// XAI search from date + /// Start date for search results. pub xai_search_from_date: Option, - /// XAI search to date + /// End date for search results. pub xai_search_to_date: Option, - /// HTTP client for making API requests - client: Client, +} + +/// Client for interacting with X.AI's API. +/// +/// This struct provides methods for making chat and completion requests to X.AI's language models. +/// It handles authentication, request configuration, and response parsing. 
+/// +/// The client uses `Arc` internally for configuration, making cloning cheap. +#[derive(Debug, Clone)] +pub struct XAI { + /// Shared configuration wrapped in Arc for cheap cloning. + pub config: Arc, + /// HTTP client for making requests. + pub client: Client, } /// Search source configuration for search parameters @@ -236,24 +249,6 @@ struct XAIResponseFormat { impl XAI { /// Creates a new X.AI client with the specified configuration. - /// - /// # Arguments - /// - /// * `api_key` - Authentication key for X.AI API access - /// * `model` - Model identifier (defaults to "grok-2-latest" if None) - /// * `max_tokens` - Maximum number of tokens to generate in responses - /// * `temperature` - Sampling temperature for controlling randomness - /// * `timeout_seconds` - Request timeout duration in seconds - /// * `system` - System prompt for providing context - /// * `stream` - Whether to enable streaming responses - /// * `top_p` - Top-p sampling parameter - /// * `top_k` - Top-k sampling parameter - /// * `json_schema` - JSON schema for structured output - /// * `search_parameters` - Search parameters for search functionality - /// - /// # Returns - /// - /// A configured X.AI client instance ready to make API requests. #[allow(clippy::too_many_arguments)] pub fn new( api_key: impl Into, @@ -278,13 +273,14 @@ impl XAI { if let Some(sec) = timeout_seconds { builder = builder.timeout(std::time::Duration::from_secs(sec)); } - Self { - api_key: api_key.into(), - model: model.unwrap_or("grok-2-latest".to_string()), + Self::with_client( + builder.build().expect("Failed to build reqwest Client"), + api_key, + model, max_tokens, temperature, - system, timeout_seconds, + system, top_p, top_k, embedding_encoding_format, @@ -296,9 +292,102 @@ impl XAI { xai_search_max_results, xai_search_from_date, xai_search_to_date, - client: builder.build().expect("Failed to build reqwest Client"), + ) + } + + /// Creates a new X.AI client with a custom HTTP client. 
+ #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + api_key: impl Into, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + embedding_encoding_format: Option, + embedding_dimensions: Option, + json_schema: Option, + xai_search_mode: Option, + xai_search_source_type: Option, + xai_search_excluded_websites: Option>, + xai_search_max_results: Option, + xai_search_from_date: Option, + xai_search_to_date: Option, + ) -> Self { + Self { + config: Arc::new(XAIConfig { + api_key: api_key.into(), + model: model.unwrap_or("grok-2-latest".to_string()), + max_tokens, + temperature, + system, + timeout_seconds, + top_p, + top_k, + embedding_encoding_format, + embedding_dimensions, + json_schema, + xai_search_mode, + xai_search_source_type, + xai_search_excluded_websites, + xai_search_max_results, + xai_search_from_date, + xai_search_to_date, + }), + client, } } + + pub fn api_key(&self) -> &str { + &self.config.api_key + } + + pub fn model(&self) -> &str { + &self.config.model + } + + pub fn max_tokens(&self) -> Option { + self.config.max_tokens + } + + pub fn temperature(&self) -> Option { + self.config.temperature + } + + pub fn timeout_seconds(&self) -> Option { + self.config.timeout_seconds + } + + pub fn system(&self) -> Option<&str> { + self.config.system.as_deref() + } + + pub fn top_p(&self) -> Option { + self.config.top_p + } + + pub fn top_k(&self) -> Option { + self.config.top_k + } + + pub fn embedding_encoding_format(&self) -> Option<&str> { + self.config.embedding_encoding_format.as_deref() + } + + pub fn embedding_dimensions(&self) -> Option { + self.config.embedding_dimensions + } + + pub fn json_schema(&self) -> Option<&StructuredOutputFormat> { + self.config.json_schema.as_ref() + } + + pub fn client(&self) -> &Client { + &self.client + } } #[async_trait] @@ -313,7 +402,7 @@ impl ChatProvider for XAI { /// /// The generated response text, or 
an error if the request fails. async fn chat(&self, messages: &[ChatMessage]) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing X.AI API key".to_string())); } @@ -328,7 +417,7 @@ impl ChatProvider for XAI { }) .collect(); - if let Some(system) = &self.system { + if let Some(system) = &self.config.system { xai_msgs.insert( 0, XAIChatMessage { @@ -342,33 +431,34 @@ impl ChatProvider for XAI { // There's currently no check for these, so we'll leave it up to the user to provide a valid schema. // Unknown if XAI requires these too, but since it copies everything else from OpenAI, it's likely. let response_format: Option = - self.json_schema.as_ref().map(|s| XAIResponseFormat { + self.config.json_schema.as_ref().map(|s| XAIResponseFormat { response_type: XAIResponseType::JsonSchema, json_schema: Some(s.clone()), }); let search_parameters = XaiSearchParameters { - mode: self.xai_search_mode.clone(), + mode: self.config.xai_search_mode.clone(), sources: Some(vec![XaiSearchSource { source_type: self + .config .xai_search_source_type .clone() .unwrap_or("web".to_string()), - excluded_websites: self.xai_search_excluded_websites.clone(), + excluded_websites: self.config.xai_search_excluded_websites.clone(), }]), - max_search_results: self.xai_search_max_results, - from_date: self.xai_search_from_date.clone(), - to_date: self.xai_search_to_date.clone(), + max_search_results: self.config.xai_search_max_results, + from_date: self.config.xai_search_from_date.clone(), + to_date: self.config.xai_search_to_date.clone(), }; let body = XAIChatRequest { - model: &self.model, + model: &self.config.model, messages: xai_msgs, - max_tokens: self.max_tokens, - temperature: self.temperature, + max_tokens: self.config.max_tokens, + temperature: self.config.temperature, stream: false, - top_p: self.top_p, - top_k: self.top_k, + top_p: self.config.top_p, + top_k: self.config.top_k, response_format, 
search_parameters: Some(&search_parameters), }; @@ -382,10 +472,10 @@ impl ChatProvider for XAI { let mut request = self .client .post("https://api.x.ai/v1/chat/completions") - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .json(&body); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } @@ -432,7 +522,7 @@ impl ChatProvider for XAI { messages: &[ChatMessage], ) -> Result> + Send>>, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing X.AI API key".to_string())); } @@ -447,7 +537,7 @@ impl ChatProvider for XAI { }) .collect(); - if let Some(system) = &self.system { + if let Some(system) = &self.config.system { xai_msgs.insert( 0, XAIChatMessage { @@ -458,13 +548,13 @@ impl ChatProvider for XAI { } let body = XAIChatRequest { - model: &self.model, + model: &self.config.model, messages: xai_msgs, - max_tokens: self.max_tokens, - temperature: self.temperature, + max_tokens: self.config.max_tokens, + temperature: self.config.temperature, stream: true, - top_p: self.top_p, - top_k: self.top_k, + top_p: self.config.top_p, + top_k: self.config.top_k, response_format: None, search_parameters: None, }; @@ -472,10 +562,10 @@ impl ChatProvider for XAI { let mut request = self .client .post("https://api.x.ai/v1/chat/completions") - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .json(&body); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } @@ -520,26 +610,27 @@ impl CompletionProvider for XAI { #[async_trait] impl EmbeddingProvider for XAI { async fn embed(&self, text: Vec) -> Result>, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing X.AI API key".into())); } let 
emb_format = self + .config .embedding_encoding_format .clone() .unwrap_or_else(|| "float".to_string()); let body = XAIEmbeddingRequest { - model: &self.model, + model: &self.config.model, input: text, encoding_format: Some(&emb_format), - dimensions: self.embedding_dimensions, + dimensions: self.config.embedding_dimensions, }; let resp = self .client .post("https://api.x.ai/v1/embeddings") - .bearer_auth(&self.api_key) + .bearer_auth(&self.config.api_key) .json(&body) .send() .await? @@ -570,16 +661,16 @@ impl ModelsProvider for XAI { &self, _request: Option<&ModelListRequest>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError("Missing X.AI API key".to_string())); } let mut request = self .client .get("https://api.x.ai/v1/models") - .bearer_auth(&self.api_key); + .bearer_auth(&self.config.api_key); - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } diff --git a/src/providers/openai_compatible.rs b/src/providers/openai_compatible.rs index 8aac07e..329cb92 100644 --- a/src/providers/openai_compatible.rs +++ b/src/providers/openai_compatible.rs @@ -22,31 +22,60 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::marker::PhantomData; use std::pin::Pin; +use std::sync::Arc; -/// Generic OpenAI-compatible provider -/// -/// This struct provides a base implementation for any OpenAI-compatible API. -/// Different providers can customize behavior by implementing the `OpenAICompatibleConfig` trait. -pub struct OpenAICompatibleProvider { +/// Configuration for OpenAI-compatible providers. +#[derive(Debug)] +pub struct OpenAICompatibleProviderConfig { + /// API key for authentication. pub api_key: String, + /// Base URL for API requests. pub base_url: Url, + /// Model identifier. pub model: String, + /// Maximum tokens to generate in responses. 
pub max_tokens: Option, + /// Sampling temperature for response randomness. pub temperature: Option, + /// System prompt to guide model behavior. pub system: Option, + /// Request timeout in seconds. pub timeout_seconds: Option, + /// Top-p (nucleus) sampling parameter. pub top_p: Option, + /// Top-k sampling parameter. pub top_k: Option, + /// Available tools for the model to use. pub tools: Option>, + /// Tool choice configuration. pub tool_choice: Option, + /// Reasoning effort level for supported models. pub reasoning_effort: Option, + /// JSON schema for structured output. pub json_schema: Option, + /// Voice setting for TTS. pub voice: Option, + /// Extra body parameters for custom fields. pub extra_body: serde_json::Map, + /// Whether to enable parallel tool calls. pub parallel_tool_calls: bool, + /// Encoding format for embeddings. pub embedding_encoding_format: Option, + /// Dimensions for embeddings. pub embedding_dimensions: Option, + /// Whether to normalize streaming responses. pub normalize_response: bool, +} + +/// Generic OpenAI-compatible provider +/// +/// This struct provides a base implementation for any OpenAI-compatible API. +/// Different providers can customize behavior by implementing the `OpenAICompatibleConfig` trait. +#[derive(Debug, Clone)] +pub struct OpenAICompatibleProvider { + /// Shared configuration wrapped in Arc for cheap cloning. + pub config: Arc, + /// HTTP client for making requests. 
pub client: Client, _phantom: PhantomData, } @@ -321,14 +350,68 @@ impl OpenAICompatibleProvider { if let Some(sec) = timeout_seconds { builder = builder.timeout(std::time::Duration::from_secs(sec)); } + let client = builder.build().expect("Failed to build reqwest Client"); + Self::with_client( + client, + api_key, + base_url, + model, + max_tokens, + temperature, + timeout_seconds, + system, + top_p, + top_k, + tools, + tool_choice, + reasoning_effort, + json_schema, + voice, + extra_body, + parallel_tool_calls, + normalize_response, + embedding_encoding_format, + embedding_dimensions, + ) + } + + /// Creates a new provider with a custom HTTP client for connection pooling + #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + api_key: impl Into, + base_url: Option, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + tools: Option>, + tool_choice: Option, + reasoning_effort: Option, + json_schema: Option, + voice: Option, + extra_body: Option, + parallel_tool_calls: Option, + normalize_response: Option, + embedding_encoding_format: Option, + embedding_dimensions: Option, + ) -> Self { let extra_body = match extra_body { Some(serde_json::Value::Object(map)) => map, - _ => serde_json::Map::new(), // Should we panic here? 
+ _ => serde_json::Map::new(), }; - Self { + let config = OpenAICompatibleProviderConfig { api_key: api_key.into(), - base_url: Url::parse(&format!("{}/", base_url.unwrap_or_else(|| T::DEFAULT_BASE_URL.to_owned()).trim_end_matches("/"))) - .expect("Failed to parse base URL"), + base_url: Url::parse(&format!( + "{}/", + base_url + .unwrap_or_else(|| T::DEFAULT_BASE_URL.to_owned()) + .trim_end_matches("/") + )) + .expect("Failed to parse base URL"), model: model.unwrap_or_else(|| T::DEFAULT_MODEL.to_string()), max_tokens, temperature, @@ -346,11 +429,94 @@ impl OpenAICompatibleProvider { normalize_response: normalize_response.unwrap_or(true), embedding_encoding_format, embedding_dimensions, - client: builder.build().expect("Failed to build reqwest Client"), + }; + Self { + config: Arc::new(config), + client, _phantom: PhantomData, } } + pub fn api_key(&self) -> &str { + &self.config.api_key + } + + pub fn base_url(&self) -> &Url { + &self.config.base_url + } + + pub fn model(&self) -> &str { + &self.config.model + } + + pub fn max_tokens(&self) -> Option { + self.config.max_tokens + } + + pub fn temperature(&self) -> Option { + self.config.temperature + } + + pub fn system(&self) -> Option<&str> { + self.config.system.as_deref() + } + + pub fn timeout_seconds(&self) -> Option { + self.config.timeout_seconds + } + + pub fn top_p(&self) -> Option { + self.config.top_p + } + + pub fn top_k(&self) -> Option { + self.config.top_k + } + + pub fn tools(&self) -> Option<&[Tool]> { + self.config.tools.as_deref() + } + + pub fn tool_choice(&self) -> Option<&ToolChoice> { + self.config.tool_choice.as_ref() + } + + pub fn reasoning_effort(&self) -> Option<&str> { + self.config.reasoning_effort.as_deref() + } + + pub fn json_schema(&self) -> Option<&StructuredOutputFormat> { + self.config.json_schema.as_ref() + } + + pub fn voice(&self) -> Option<&str> { + self.config.voice.as_deref() + } + + pub fn extra_body(&self) -> &serde_json::Map { + &self.config.extra_body + } + + pub fn 
parallel_tool_calls(&self) -> bool { + self.config.parallel_tool_calls + } + + pub fn embedding_encoding_format(&self) -> Option<&str> { + self.config.embedding_encoding_format.as_deref() + } + + pub fn embedding_dimensions(&self) -> Option { + self.config.embedding_dimensions + } + + pub fn normalize_response(&self) -> bool { + self.config.normalize_response + } + + pub fn client(&self) -> &Client { + &self.client + } + pub fn prepare_messages(&self, messages: &[ChatMessage]) -> Vec> { let mut openai_msgs: Vec = messages .iter() @@ -372,7 +538,7 @@ impl OpenAICompatibleProvider { } }) .collect(); - if let Some(system) = &self.system { + if let Some(system) = &self.config.system { openai_msgs.insert( 0, OpenAIChatMessage { @@ -401,7 +567,7 @@ impl ChatProvider for OpenAICompatibleProvider { messages: &[ChatMessage], tools: Option<&[Tool]>, ) -> Result, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError(format!( "Missing {} API key", T::PROVIDER_NAME @@ -409,47 +575,54 @@ impl ChatProvider for OpenAICompatibleProvider { } let openai_msgs = self.prepare_messages(messages); let response_format: Option = if T::SUPPORTS_STRUCTURED_OUTPUT { - self.json_schema.clone().map(|s| s.into()) + self.config.json_schema.clone().map(|s| s.into()) } else { None }; - let request_tools = tools.map(|t| t.to_vec()).or_else(|| self.tools.clone()); + let request_tools = tools + .map(|t| t.to_vec()) + .or_else(|| self.config.tools.clone()); let request_tool_choice = if request_tools.is_some() { - self.tool_choice.clone() + self.config.tool_choice.clone() } else { None }; let reasoning_effort = if T::SUPPORTS_REASONING_EFFORT { - self.reasoning_effort.clone() + self.config.reasoning_effort.clone() } else { None }; let parallel_tool_calls = if T::SUPPORTS_PARALLEL_TOOL_CALLS { - Some(self.parallel_tool_calls) + Some(self.config.parallel_tool_calls) } else { None }; let body = OpenAIChatRequest { - model: &self.model, + model: 
&self.config.model, messages: openai_msgs, - max_tokens: self.max_tokens, - temperature: self.temperature, + max_tokens: self.config.max_tokens, + temperature: self.config.temperature, stream: false, - top_p: self.top_p, - top_k: self.top_k, + top_p: self.config.top_p, + top_k: self.config.top_k, tools: request_tools, tool_choice: request_tool_choice, reasoning_effort, response_format, stream_options: None, parallel_tool_calls, - extra_body: self.extra_body.clone(), + extra_body: self.config.extra_body.clone(), }; let url = self + .config .base_url .join(T::CHAT_ENDPOINT) .map_err(|e| LLMError::HttpError(e.to_string()))?; - let mut request = self.client.post(url).bearer_auth(&self.api_key).json(&body); + let mut request = self + .client + .post(url) + .bearer_auth(&self.config.api_key) + .json(&body); // Add custom headers if provider specifies them if let Some(headers) = T::custom_headers() { for (key, value) in headers { @@ -461,7 +634,7 @@ impl ChatProvider for OpenAICompatibleProvider { log::trace!("{} request payload: {}", T::PROVIDER_NAME, json); } } - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } let response = request.send().await?; @@ -524,7 +697,7 @@ impl ChatProvider for OpenAICompatibleProvider { std::pin::Pin> + Send>>, LLMError, > { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError(format!( "Missing {} API key", T::PROVIDER_NAME @@ -532,17 +705,17 @@ impl ChatProvider for OpenAICompatibleProvider { } let openai_msgs = self.prepare_messages(messages); let body = OpenAIChatRequest { - model: &self.model, + model: &self.config.model, messages: openai_msgs, - max_tokens: self.max_tokens, - temperature: self.temperature, + max_tokens: self.config.max_tokens, + temperature: self.config.temperature, stream: true, - top_p: self.top_p, - top_k: self.top_k, - tools: self.tools.clone(), - 
tool_choice: self.tool_choice.clone(), + top_p: self.config.top_p, + top_k: self.config.top_k, + tools: self.config.tools.clone(), + tool_choice: self.config.tool_choice.clone(), reasoning_effort: if T::SUPPORTS_REASONING_EFFORT { - self.reasoning_effort.clone() + self.config.reasoning_effort.clone() } else { None }, @@ -555,17 +728,22 @@ impl ChatProvider for OpenAICompatibleProvider { None }, parallel_tool_calls: if T::SUPPORTS_PARALLEL_TOOL_CALLS { - Some(self.parallel_tool_calls) + Some(self.config.parallel_tool_calls) } else { None }, - extra_body: self.extra_body.clone(), + extra_body: self.config.extra_body.clone(), }; let url = self + .config .base_url .join(T::CHAT_ENDPOINT) .map_err(|e| LLMError::HttpError(e.to_string()))?; - let mut request = self.client.post(url).bearer_auth(&self.api_key).json(&body); + let mut request = self + .client + .post(url) + .bearer_auth(&self.config.api_key) + .json(&body); if let Some(headers) = T::custom_headers() { for (key, value) in headers { request = request.header(key, value); @@ -576,7 +754,7 @@ impl ChatProvider for OpenAICompatibleProvider { log::trace!("{} request payload: {}", T::PROVIDER_NAME, json); } } - if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); } let response = request.send().await?; @@ -589,7 +767,7 @@ impl ChatProvider for OpenAICompatibleProvider { raw_response: error_text, }); } - Ok(create_sse_stream(response, self.normalize_response)) + Ok(create_sse_stream(response, self.config.normalize_response)) } /// Sends a streaming chat request with tool support. 
@@ -612,7 +790,7 @@ impl ChatProvider for OpenAICompatibleProvider { tools: Option<&[Tool]>, ) -> Result> + Send>>, LLMError> { - if self.api_key.is_empty() { + if self.config.api_key.is_empty() { return Err(LLMError::AuthError(format!( "Missing {} API key", T::PROVIDER_NAME @@ -622,20 +800,22 @@ impl ChatProvider for OpenAICompatibleProvider { let openai_msgs = self.prepare_messages(messages); // Use provided tools or fall back to configured tools - let effective_tools = tools.map(|t| t.to_vec()).or_else(|| self.tools.clone()); + let effective_tools = tools + .map(|t| t.to_vec()) + .or_else(|| self.config.tools.clone()); let body = OpenAIChatRequest { - model: &self.model, + model: &self.config.model, messages: openai_msgs, - max_tokens: self.max_tokens, - temperature: self.temperature, + max_tokens: self.config.max_tokens, + temperature: self.config.temperature, stream: true, - top_p: self.top_p, - top_k: self.top_k, + top_p: self.config.top_p, + top_k: self.config.top_k, tools: effective_tools, - tool_choice: self.tool_choice.clone(), + tool_choice: self.config.tool_choice.clone(), reasoning_effort: if T::SUPPORTS_REASONING_EFFORT { - self.reasoning_effort.clone() + self.config.reasoning_effort.clone() } else { None }, @@ -648,19 +828,24 @@ impl ChatProvider for OpenAICompatibleProvider { None }, parallel_tool_calls: if T::SUPPORTS_PARALLEL_TOOL_CALLS { - Some(self.parallel_tool_calls) + Some(self.config.parallel_tool_calls) } else { None }, - extra_body: self.extra_body.clone(), + extra_body: self.config.extra_body.clone(), }; let url = self + .config .base_url .join(T::CHAT_ENDPOINT) .map_err(|e| LLMError::HttpError(e.to_string()))?; - let mut request = self.client.post(url).bearer_auth(&self.api_key).json(&body); + let mut request = self + .client + .post(url) + .bearer_auth(&self.config.api_key) + .json(&body); if let Some(headers) = T::custom_headers() { for (key, value) in headers { @@ -678,7 +863,7 @@ impl ChatProvider for OpenAICompatibleProvider { } } - 
if let Some(timeout) = self.timeout_seconds { + if let Some(timeout) = self.config.timeout_seconds { request = request.timeout(std::time::Duration::from_secs(timeout)); }