From b9915146a91caed902407d0dafebb435e925f93d Mon Sep 17 00:00:00 2001 From: Naz Quadri Date: Wed, 31 Dec 2025 11:49:58 -0500 Subject: [PATCH] feat: add Clone support and HTTP client pooling for providers --- src/backends/anthropic.rs | 116 +++- src/backends/azure_openai.rs | 51 ++ src/backends/cohere.rs | 48 ++ src/backends/deepseek.rs | 25 + src/backends/elevenlabs.rs | 29 +- src/backends/google.rs | 34 ++ src/backends/groq.rs | 51 ++ src/backends/huggingface.rs | 48 ++ src/backends/mistral.rs | 48 ++ src/backends/ollama.rs | 36 ++ src/backends/openai.rs | 71 +++ src/backends/openrouter.rs | 48 ++ src/backends/phind.rs | 29 + src/backends/xai.rs | 48 ++ src/builder.rs | 833 ++++++++++++++++++++--------- src/providers/openai_compatible.rs | 174 ++++++ 16 files changed, 1446 insertions(+), 243 deletions(-) diff --git a/src/backends/anthropic.rs b/src/backends/anthropic.rs index 89fb600..787cbf3 100644 --- a/src/backends/anthropic.rs +++ b/src/backends/anthropic.rs @@ -29,7 +29,7 @@ use serde_json::Value; /// Client for interacting with Anthropic's API. /// /// Provides methods for chat and completion requests using Anthropic's models. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Anthropic { pub api_key: String, pub model: String, @@ -506,6 +506,43 @@ impl Anthropic { client: builder.build().expect("Failed to build reqwest Client"), } } + + /// Creates a new Anthropic client with a pre-configured HTTP client. + /// + /// This allows sharing a single `reqwest::Client` across multiple providers, + /// enabling connection pooling and reducing resource usage. + #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + api_key: impl Into, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + tools: Option>, + tool_choice: Option, + reasoning: Option, + thinking_budget_tokens: Option, + ) -> Self { + Self { + api_key: api_key.into(), + model: model.unwrap_or_else(|| "claude-3-sonnet-20240229".to_string()), + max_tokens: max_tokens.unwrap_or(300), + temperature: temperature.unwrap_or(0.7), + system: system.unwrap_or_else(|| "You are a helpful assistant.".to_string()), + timeout_seconds: timeout_seconds.unwrap_or(30), + top_p, + top_k, + tools, + tool_choice, + reasoning: reasoning.unwrap_or(false), + thinking_budget_tokens, + client, + } + } } #[async_trait] @@ -1391,4 +1428,81 @@ data: {"type": "ping"} let result = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); assert!(result.is_none()); } + + #[test] + fn test_anthropic_clone() { + let anthropic = Anthropic::new( + "test-api-key", + Some("claude-3-sonnet".to_string()), + Some(1000), + Some(0.7), + Some(30), + Some("You are helpful.".to_string()), + None, + None, + None, + None, + None, + None, + ); + + // Clone the provider + let cloned = anthropic.clone(); + + // Verify both have the same configuration + assert_eq!(anthropic.api_key, cloned.api_key); + assert_eq!(anthropic.model, cloned.model); + assert_eq!(anthropic.max_tokens, cloned.max_tokens); + assert_eq!(anthropic.temperature, cloned.temperature); + assert_eq!(anthropic.system, cloned.system); + } + + #[test] + fn test_anthropic_with_client() { + let shared_client = Client::builder() + .timeout(std::time::Duration::from_secs(60)) + .build() + .expect("Failed to build client"); + + let anthropic = Anthropic::with_client( + shared_client.clone(), + "test-api-key", + Some("claude-3-sonnet".to_string()), + Some(1000), + Some(0.7), + Some(30), + Some("You are helpful.".to_string()), + None, + None, + None, + None, + None, + None, + ); + + // Verify configuration + assert_eq!(anthropic.api_key, "test-api-key"); + assert_eq!(anthropic.model, "claude-3-sonnet"); + assert_eq!(anthropic.max_tokens, 1000); + + // Create another provider with the same client + let anthropic2 = Anthropic::with_client( + shared_client, + "test-api-key-2", + Some("claude-3-haiku".to_string()), + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ); + + assert_eq!(anthropic2.api_key, "test-api-key-2"); + assert_eq!(anthropic2.model, "claude-3-haiku"); + } } diff --git a/src/backends/azure_openai.rs b/src/backends/azure_openai.rs index 562f755..a813a56 100644 --- a/src/backends/azure_openai.rs +++ b/src/backends/azure_openai.rs @@ -27,6 +27,7 @@ use serde::{Deserialize, Serialize}; /// Client for interacting with Azure OpenAI's API. /// /// Provides methods for chat and completion requests using Azure OpenAI's models. +#[derive(Clone)] pub struct AzureOpenAI { pub api_key: String, pub api_version: String, @@ -381,6 +382,56 @@ impl AzureOpenAI { json_schema, } } + + /// Creates a new Azure OpenAI client with a pre-configured HTTP client. + /// + /// This allows sharing a single `reqwest::Client` across multiple providers, + /// enabling connection pooling and reducing resource usage. + #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + api_key: impl Into, + api_version: impl Into, + deployment_id: impl Into, + endpoint: impl Into, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + embedding_encoding_format: Option, + embedding_dimensions: Option, + tools: Option>, + tool_choice: Option, + reasoning_effort: Option, + json_schema: Option, + ) -> Self { + let endpoint = endpoint.into(); + let deployment_id = deployment_id.into(); + + Self { + api_key: api_key.into(), + api_version: api_version.into(), + base_url: Url::parse(&format!("{endpoint}/openai/deployments/{deployment_id}/")) + .expect("Failed to parse base Url"), + model: model.unwrap_or("gpt-3.5-turbo".to_string()), + max_tokens, + temperature, + system, + timeout_seconds, + top_p, + top_k, + tools, + tool_choice, + embedding_encoding_format, + embedding_dimensions, + client, + reasoning_effort, + json_schema, + } + } } #[async_trait] diff --git a/src/backends/cohere.rs b/src/backends/cohere.rs index a835556..b4b5ce9 100644 --- a/src/backends/cohere.rs +++ b/src/backends/cohere.rs @@ -17,6 +17,7 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; /// Cohere configuration for the generic provider +#[derive(Clone)] pub struct CohereConfig; impl OpenAIProviderConfig for CohereConfig { @@ -78,6 +79,53 @@ impl Cohere { embedding_dimensions, ) } + + /// Creates a new Cohere client with a pre-configured HTTP client. + #[allow(clippy::too_many_arguments)] + pub fn with_config_and_client( + client: reqwest::Client, + api_key: impl Into, + base_url: Option, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + tools: Option>, + tool_choice: Option, + extra_body: Option, + embedding_encoding_format: Option, + embedding_dimensions: Option, + reasoning_effort: Option, + json_schema: Option, + parallel_tool_calls: Option, + normalize_response: Option, + ) -> Self { + >::with_client( + client, + api_key, + base_url, + model, + max_tokens, + temperature, + timeout_seconds, + system, + top_p, + top_k, + tools, + tool_choice, + reasoning_effort, + json_schema, + None, + extra_body, + parallel_tool_calls, + normalize_response, + embedding_encoding_format, + embedding_dimensions, + ) + } } // Cohere-specific implementations that don't fit in the generic OpenAI-compatible provider diff --git a/src/backends/deepseek.rs b/src/backends/deepseek.rs index 7ef83bf..42848bc 100644 --- a/src/backends/deepseek.rs +++ b/src/backends/deepseek.rs @@ -21,6 +21,7 @@ use serde::{Deserialize, Serialize}; use crate::ToolCall; +#[derive(Clone)] pub struct DeepSeek { pub api_key: String, pub model: String, @@ -105,6 +106,30 @@ impl DeepSeek { client: builder.build().expect("Failed to build reqwest Client"), } } + + /// Creates a new DeepSeek client with a pre-configured HTTP client. + /// + /// This allows sharing a single `reqwest::Client` across multiple providers, + /// enabling connection pooling and reducing resource usage. + pub fn with_client( + client: Client, + api_key: impl Into, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + ) -> Self { + Self { + api_key: api_key.into(), + model: model.unwrap_or("deepseek-chat".to_string()), + max_tokens, + temperature, + system, + timeout_seconds, + client, + } + } } #[async_trait] diff --git a/src/backends/elevenlabs.rs b/src/backends/elevenlabs.rs index 07f3bd2..f55bead 100644 --- a/src/backends/elevenlabs.rs +++ b/src/backends/elevenlabs.rs @@ -16,6 +16,7 @@ use std::time::Duration; /// /// This struct provides functionality for speech-to-text transcription using the ElevenLabs API. /// It implements various LLM provider traits but only supports speech-to-text functionality. +#[derive(Clone)] pub struct ElevenLabs { /// API key for ElevenLabs authentication api_key: String, @@ -91,13 +92,39 @@ impl ElevenLabs { base_url: String, timeout_seconds: Option, voice: Option, + ) -> Self { + let mut builder = Client::builder(); + if let Some(sec) = timeout_seconds { + builder = builder.timeout(Duration::from_secs(sec)); + } + Self { + api_key, + model_id, + base_url, + timeout_seconds, + client: builder.build().expect("Failed to build reqwest Client"), + voice, + } + } + + /// Creates a new ElevenLabs client with a pre-configured HTTP client. + /// + /// This allows sharing a single `reqwest::Client` across multiple providers, + /// enabling connection pooling and reducing resource usage. + pub fn with_client( + client: Client, + api_key: String, + model_id: String, + base_url: String, + timeout_seconds: Option, + voice: Option, ) -> Self { Self { api_key, model_id, base_url, timeout_seconds, - client: Client::new(), + client, voice, } } diff --git a/src/backends/google.rs b/src/backends/google.rs index 32fbb9d..2d5330a 100644 --- a/src/backends/google.rs +++ b/src/backends/google.rs @@ -66,6 +66,7 @@ use serde_json::Value; /// /// This struct holds the configuration and state needed to make requests to the Gemini API. /// It implements the [`ChatProvider`], [`CompletionProvider`], and [`EmbeddingProvider`] traits. +#[derive(Clone)] pub struct Google { /// API key for authentication with Google's API pub api_key: String, @@ -514,6 +515,39 @@ impl Google { client: builder.build().expect("Failed to build reqwest Client"), } } + + /// Creates a new Google client with a pre-configured HTTP client. + /// + /// This allows sharing a single `reqwest::Client` across multiple providers, + /// enabling connection pooling and reducing resource usage. + #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + api_key: impl Into, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + json_schema: Option, + tools: Option>, + ) -> Self { + Self { + api_key: api_key.into(), + model: model.unwrap_or_else(|| "gemini-1.5-flash".to_string()), + max_tokens, + temperature, + system, + timeout_seconds, + top_p, + top_k, + json_schema, + tools, + client, + } + } } #[async_trait] diff --git a/src/backends/groq.rs b/src/backends/groq.rs index 6a7a97f..d4da244 100644 --- a/src/backends/groq.rs +++ b/src/backends/groq.rs @@ -17,6 +17,7 @@ use crate::{ use async_trait::async_trait; /// Groq configuration for the generic provider +#[derive(Clone)] pub struct GroqConfig; impl OpenAIProviderConfig for GroqConfig { @@ -89,6 +90,56 @@ impl Groq { None, // embedding_dimensions - not supported by Groq ) } + + /// Creates a new Groq client with a pre-configured HTTP client. + /// + /// This allows sharing a single `reqwest::Client` across multiple providers, + /// enabling connection pooling and reducing resource usage. + #[allow(clippy::too_many_arguments)] + pub fn with_config_and_client( + client: reqwest::Client, + api_key: impl Into, + base_url: Option, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + tools: Option>, + tool_choice: Option, + extra_body: Option, + _embedding_encoding_format: Option, + _embedding_dimensions: Option, + reasoning_effort: Option, + json_schema: Option, + parallel_tool_calls: Option, + normalize_response: Option, + ) -> Self { + OpenAICompatibleProvider::::with_client( + client, + api_key, + base_url, + model, + max_tokens, + temperature, + timeout_seconds, + system, + top_p, + top_k, + tools, + tool_choice, + reasoning_effort, + json_schema, + None, // voice - not supported by Groq + extra_body, + parallel_tool_calls, + normalize_response, + None, // embedding_encoding_format - not supported by Groq + None, // embedding_dimensions - not supported by Groq + ) + } } impl LLMProvider for Groq { diff --git a/src/backends/huggingface.rs b/src/backends/huggingface.rs index 233ab85..731e3eb 100644 --- a/src/backends/huggingface.rs +++ b/src/backends/huggingface.rs @@ -17,6 +17,7 @@ use crate::{ use async_trait::async_trait; /// HuggingFace configuration for the generic provider +#[derive(Clone)] pub struct HuggingFaceConfig; impl OpenAIProviderConfig for HuggingFaceConfig { @@ -75,6 +76,53 @@ impl HuggingFace { None, // embedding_dimensions ) } + + /// Creates a new HuggingFace client with a pre-configured HTTP client. + #[allow(clippy::too_many_arguments)] + pub fn with_config_and_client( + client: reqwest::Client, + api_key: impl Into, + base_url: Option, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + tools: Option>, + tool_choice: Option, + extra_body: Option, + _embedding_encoding_format: Option, + _embedding_dimensions: Option, + reasoning_effort: Option, + json_schema: Option, + parallel_tool_calls: Option, + normalize_response: Option, + ) -> Self { + OpenAICompatibleProvider::::with_client( + client, + api_key, + base_url, + model, + max_tokens, + temperature, + timeout_seconds, + system, + top_p, + top_k, + tools, + tool_choice, + reasoning_effort, + json_schema, + None, + extra_body, + parallel_tool_calls, + normalize_response, + None, + None, + ) + } } impl LLMProvider for HuggingFace { diff --git a/src/backends/mistral.rs b/src/backends/mistral.rs index 761d915..6201025 100644 --- a/src/backends/mistral.rs +++ b/src/backends/mistral.rs @@ -19,6 +19,7 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; /// Mistral configuration for the generic provider +#[derive(Clone)] pub struct MistralConfig; impl OpenAIProviderConfig for MistralConfig { @@ -78,6 +79,53 @@ impl Mistral { embedding_dimensions, ) } + + /// Creates a new Mistral client with a pre-configured HTTP client. + #[allow(clippy::too_many_arguments)] + pub fn with_config_and_client( + client: reqwest::Client, + api_key: impl Into, + base_url: Option, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + tools: Option>, + tool_choice: Option, + extra_body: Option, + embedding_encoding_format: Option, + embedding_dimensions: Option, + reasoning_effort: Option, + json_schema: Option, + parallel_tool_calls: Option, + normalize_response: Option, + ) -> Self { + >::with_client( + client, + api_key, + base_url, + model, + max_tokens, + temperature, + timeout_seconds, + system, + top_p, + top_k, + tools, + tool_choice, + reasoning_effort, + json_schema, + None, + extra_body, + parallel_tool_calls, + normalize_response, + embedding_encoding_format, + embedding_dimensions, + ) + } } // Mistral-specific implementations that don't fit in the generic OpenAI-compatible provider diff --git a/src/backends/ollama.rs b/src/backends/ollama.rs index cd71fec..d822b59 100644 --- a/src/backends/ollama.rs +++ b/src/backends/ollama.rs @@ -29,6 +29,7 @@ use serde_json::Value; /// Client for interacting with Ollama's API. /// /// Provides methods for chat and completion requests using Ollama's models. +#[derive(Clone)] pub struct Ollama { pub base_url: String, pub api_key: Option, @@ -337,6 +338,41 @@ impl Ollama { } } + /// Creates a new Ollama client with a pre-configured HTTP client. + /// + /// This allows sharing a single `reqwest::Client` across multiple providers, + /// enabling connection pooling and reducing resource usage. + #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + base_url: impl Into, + api_key: Option, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + json_schema: Option, + tools: Option>, + ) -> Self { + Self { + base_url: base_url.into(), + api_key, + model: model.unwrap_or("llama3.1".to_string()), + temperature, + max_tokens, + timeout_seconds, + system, + top_p, + top_k, + json_schema, + tools, + client, + } + } + fn make_chat_request<'a>( &'a self, messages: &'a [ChatMessage], diff --git a/src/backends/openai.rs b/src/backends/openai.rs index 026a9af..6517b59 100644 --- a/src/backends/openai.rs +++ b/src/backends/openai.rs @@ -23,10 +23,12 @@ use crate::{ }; use async_trait::async_trait; use futures::{Stream, StreamExt}; +use reqwest::Client; use serde::{Deserialize, Serialize}; use std::time::Duration; /// OpenAI configuration for the generic provider +#[derive(Clone)] struct OpenAIConfig; impl OpenAIProviderConfig for OpenAIConfig { @@ -42,6 +44,7 @@ impl OpenAIProviderConfig for OpenAIConfig { // NOTE: OpenAI cannot directly use the OpenAICompatibleProvider type alias, as it needs specific fields /// Client for OpenAI API +#[derive(Clone)] pub struct OpenAI { // Delegate to the generic provider for common functionality provider: OpenAICompatibleProvider, @@ -251,6 +254,74 @@ impl OpenAI { web_search_user_location_approximate_region, }) } + + /// Creates a new OpenAI client with a pre-configured HTTP client. + /// + /// This allows sharing a single `reqwest::Client` across multiple providers, + /// enabling connection pooling and reducing resource usage. + #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + api_key: impl Into, + base_url: Option, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + embedding_encoding_format: Option, + embedding_dimensions: Option, + tools: Option>, + tool_choice: Option, + normalize_response: Option, + reasoning_effort: Option, + json_schema: Option, + voice: Option, + extra_body: Option, + enable_web_search: Option, + web_search_context_size: Option, + web_search_user_location_type: Option, + web_search_user_location_approximate_country: Option, + web_search_user_location_approximate_city: Option, + web_search_user_location_approximate_region: Option, + ) -> Result { + let api_key_str = api_key.into(); + if api_key_str.is_empty() { + return Err(LLMError::AuthError("Missing OpenAI API key".to_string())); + } + Ok(OpenAI { + provider: >::with_client( + client, + api_key_str, + base_url, + model, + max_tokens, + temperature, + timeout_seconds, + system, + top_p, + top_k, + tools, + tool_choice, + reasoning_effort, + json_schema, + voice, + extra_body, + None, // parallel_tool_calls + normalize_response, + embedding_encoding_format, + embedding_dimensions, + ), + enable_web_search: enable_web_search.unwrap_or(false), + web_search_context_size, + web_search_user_location_type, + web_search_user_location_approximate_country, + web_search_user_location_approximate_city, + web_search_user_location_approximate_region, + }) + } } // OpenAI-specific implementations that don't fit in the generic provider diff --git a/src/backends/openrouter.rs b/src/backends/openrouter.rs index c0ca170..c9a8c28 100644 --- a/src/backends/openrouter.rs +++ b/src/backends/openrouter.rs @@ -17,6 +17,7 @@ use crate::{ use async_trait::async_trait; /// OpenRouter configuration for the generic provider +#[derive(Clone)] pub struct OpenRouterConfig; impl OpenAIProviderConfig for OpenRouterConfig { @@ -75,6 +76,53 @@ impl OpenRouter { None, // embedding_dimensions - not supported by OpenRouter ) } + + /// Creates a new OpenRouter client with a pre-configured HTTP client. + #[allow(clippy::too_many_arguments)] + pub fn with_config_and_client( + client: reqwest::Client, + api_key: impl Into, + base_url: Option, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + tools: Option>, + tool_choice: Option, + extra_body: Option, + _embedding_encoding_format: Option, + _embedding_dimensions: Option, + reasoning_effort: Option, + json_schema: Option, + parallel_tool_calls: Option, + normalize_response: Option, + ) -> Self { + OpenAICompatibleProvider::::with_client( + client, + api_key, + base_url, + model, + max_tokens, + temperature, + timeout_seconds, + system, + top_p, + top_k, + tools, + tool_choice, + reasoning_effort, + json_schema, + None, + extra_body, + parallel_tool_calls, + normalize_response, + None, + None, + ) + } } impl LLMProvider for OpenRouter { diff --git a/src/backends/phind.rs b/src/backends/phind.rs index 4698bc8..561de3a 100644 --- a/src/backends/phind.rs +++ b/src/backends/phind.rs @@ -22,6 +22,7 @@ use reqwest::{Client, Response}; use serde_json::{json, Value}; /// Represents a Phind LLM client with configuration options. +#[derive(Clone)] pub struct Phind { /// The model identifier to use (e.g. "Phind-70B") pub model: String, @@ -93,6 +94,34 @@ impl Phind { } } + /// Creates a new Phind client with a pre-configured HTTP client. + /// + /// This allows sharing a single `reqwest::Client` across multiple providers, + /// enabling connection pooling and reducing resource usage. + #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + ) -> Self { + Self { + model: model.unwrap_or_else(|| "Phind-70B".to_string()), + max_tokens, + temperature, + system, + timeout_seconds, + top_p, + top_k, + api_base_url: "https://https.extension.phind.com/agent/".to_string(), + client, + } + } + /// Creates the required headers for API requests. fn create_headers() -> Result { let mut headers = HeaderMap::new(); diff --git a/src/backends/xai.rs b/src/backends/xai.rs index 1028829..4e49a20 100644 --- a/src/backends/xai.rs +++ b/src/backends/xai.rs @@ -27,6 +27,7 @@ use serde::{Deserialize, Serialize}; /// /// This struct provides methods for making chat and completion requests to X.AI's language models. /// It handles authentication, request configuration, and response parsing. +#[derive(Clone)] pub struct XAI { /// API key for authentication with X.AI services pub api_key: String, @@ -299,6 +300,53 @@ impl XAI { client: builder.build().expect("Failed to build reqwest Client"), } } + + /// Creates a new X.AI client with a pre-configured HTTP client. + /// + /// This allows sharing a single `reqwest::Client` across multiple providers, + /// enabling connection pooling and reducing resource usage. + #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + api_key: impl Into, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + embedding_encoding_format: Option, + embedding_dimensions: Option, + json_schema: Option, + xai_search_mode: Option, + xai_search_source_type: Option, + xai_search_excluded_websites: Option>, + xai_search_max_results: Option, + xai_search_from_date: Option, + xai_search_to_date: Option, + ) -> Self { + Self { + api_key: api_key.into(), + model: model.unwrap_or("grok-2-latest".to_string()), + max_tokens, + temperature, + system, + timeout_seconds, + top_p, + top_k, + embedding_encoding_format, + embedding_dimensions, + json_schema, + xai_search_mode, + xai_search_source_type, + xai_search_excluded_websites, + xai_search_max_results, + xai_search_from_date, + xai_search_to_date, + client, + } + } } #[async_trait] diff --git a/src/builder.rs b/src/builder.rs index 92033b1..e1a7384 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -198,6 +198,8 @@ pub struct LLMBuilder { resilient_max_delay_ms: Option, /// Resilience: jitter toggle resilient_jitter: Option, + /// Pre-configured HTTP client for connection pooling + client: Option, } impl LLMBuilder { @@ -476,6 +478,36 @@ impl LLMBuilder { self } + /// Sets a pre-configured HTTP client for connection pooling. + /// + /// This allows sharing a single `reqwest::Client` across multiple providers, + /// enabling connection pooling and reducing resource usage. + /// + /// # Arguments + /// + /// * `client` - A pre-configured `reqwest::Client` to use for HTTP requests + /// + /// # Examples + /// + /// ```rust + /// use llm::builder::{LLMBuilder, LLMBackend}; + /// use reqwest::Client; + /// use std::time::Duration; + /// + /// let shared_client = Client::builder() + /// .timeout(Duration::from_secs(60)) + /// .build() + /// .unwrap(); + /// + /// let builder = LLMBuilder::new() + /// .backend(LLMBackend::OpenAI) + /// .client(shared_client); + /// ``` + pub fn client(mut self, client: reqwest::Client) -> Self { + self.client = Some(client); + self + } + #[deprecated(note = "Renamed to `xai_search_mode`.")] pub fn search_mode(self, mode: impl Into) -> Self { self.xai_search_mode(mode) @@ -658,32 +690,63 @@ impl LLMBuilder { let key = self.api_key.ok_or_else(|| { LLMError::InvalidRequest("No API key provided for OpenAI".to_string()) })?; - Box::new(crate::backends::openai::OpenAI::new( - key, - self.base_url, - self.model, - self.max_tokens, - self.temperature, - self.timeout_seconds, - self.system, - self.top_p, - self.top_k, - self.embedding_encoding_format, - self.embedding_dimensions, - tools, - tool_choice, - self.normalize_response, - self.reasoning_effort, - self.json_schema, - self.voice, - self.extra_body, - self.openai_enable_web_search, - self.openai_web_search_context_size, - self.openai_web_search_user_location_type, - self.openai_web_search_user_location_approximate_country, - self.openai_web_search_user_location_approximate_city, - self.openai_web_search_user_location_approximate_region, - )?) + let openai = if let Some(client) = self.client.clone() { + crate::backends::openai::OpenAI::with_client( + client, + key, + self.base_url.clone(), + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + self.embedding_encoding_format.clone(), + self.embedding_dimensions, + tools.clone(), + tool_choice.clone(), + self.normalize_response, + self.reasoning_effort.clone(), + self.json_schema.clone(), + self.voice.clone(), + self.extra_body.clone(), + self.openai_enable_web_search, + self.openai_web_search_context_size.clone(), + self.openai_web_search_user_location_type.clone(), + self.openai_web_search_user_location_approximate_country.clone(), + self.openai_web_search_user_location_approximate_city.clone(), + self.openai_web_search_user_location_approximate_region.clone(), + )? + } else { + crate::backends::openai::OpenAI::new( + key, + self.base_url.clone(), + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + self.embedding_encoding_format.clone(), + self.embedding_dimensions, + tools.clone(), + tool_choice.clone(), + self.normalize_response, + self.reasoning_effort.clone(), + self.json_schema.clone(), + self.voice.clone(), + self.extra_body.clone(), + self.openai_enable_web_search, + self.openai_web_search_context_size.clone(), + self.openai_web_search_user_location_type.clone(), + self.openai_web_search_user_location_approximate_country.clone(), + self.openai_web_search_user_location_approximate_city.clone(), + self.openai_web_search_user_location_approximate_region.clone(), + )? + }; + Box::new(openai) } } LLMBackend::ElevenLabs => { @@ -693,16 +756,28 @@ impl LLMBuilder { )); #[cfg(feature = "elevenlabs")] { - let api_key = self.api_key.ok_or_else(|| { + let api_key = self.api_key.clone().ok_or_else(|| { LLMError::InvalidRequest("No API key provided for ElevenLabs".to_string()) })?; - let elevenlabs = crate::backends::elevenlabs::ElevenLabs::new( - api_key, - self.model.unwrap_or("eleven_multilingual_v2".to_string()), - "https://api.elevenlabs.io/v1".to_string(), - self.timeout_seconds, - self.voice, - ); + let model = self.model.clone().unwrap_or("eleven_multilingual_v2".to_string()); + let elevenlabs = if let Some(client) = self.client.clone() { + crate::backends::elevenlabs::ElevenLabs::with_client( + client, + api_key, + model, + "https://api.elevenlabs.io/v1".to_string(), + self.timeout_seconds, + self.voice.clone(), + ) + } else { + crate::backends::elevenlabs::ElevenLabs::new( + api_key, + model, + "https://api.elevenlabs.io/v1".to_string(), + self.timeout_seconds, + self.voice.clone(), + ) + }; Box::new(elevenlabs) } } @@ -716,20 +791,38 @@ impl LLMBuilder { let api_key = self.api_key.ok_or_else(|| { LLMError::InvalidRequest("No API key provided for Anthropic".to_string()) })?; - let anthro = crate::backends::anthropic::Anthropic::new( - api_key, - self.model, - self.max_tokens, - self.temperature, - self.timeout_seconds, - self.system, - self.top_p, - self.top_k, - tools, - self.tool_choice, - self.reasoning, - self.reasoning_budget_tokens, - ); + let anthro = if let Some(client) = self.client.clone() { + crate::backends::anthropic::Anthropic::with_client( + client, + api_key, + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + tools.clone(), + self.tool_choice.clone(), + self.reasoning, + self.reasoning_budget_tokens, + ) + } else { + crate::backends::anthropic::Anthropic::new( + api_key, + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + tools.clone(), + self.tool_choice.clone(), + self.reasoning, + self.reasoning_budget_tokens, + ) + }; Box::new(anthro) } } @@ -742,20 +835,38 @@ impl LLMBuilder { { let url = self .base_url + .clone() .unwrap_or("http://localhost:11434".to_string()); - let ollama = crate::backends::ollama::Ollama::new( - url, - self.api_key, - self.model, - self.max_tokens, - self.temperature, - self.timeout_seconds, - self.system, - self.top_p, - self.top_k, - self.json_schema, - tools, - ); + let ollama = if let Some(client) = self.client.clone() { + crate::backends::ollama::Ollama::with_client( + client, + url, + self.api_key.clone(), + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + self.json_schema.clone(), + tools.clone(), + ) + } else { + crate::backends::ollama::Ollama::new( + url, + self.api_key.clone(), + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + self.json_schema.clone(), + tools.clone(), + ) + }; Box::new(ollama) } } @@ -767,17 +878,29 @@ impl LLMBuilder { #[cfg(feature = "deepseek")] { - let api_key = self.api_key.ok_or_else(|| { + let api_key = self.api_key.clone().ok_or_else(|| { LLMError::InvalidRequest("No API key provided for DeepSeek".to_string()) })?; - let deepseek = crate::backends::deepseek::DeepSeek::new( - api_key, - self.model, - self.max_tokens, - self.temperature, - self.timeout_seconds, - self.system, - ); + let deepseek = if let Some(client) = self.client.clone() { + crate::backends::deepseek::DeepSeek::with_client( + client, + api_key, + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + ) + } else { + crate::backends::deepseek::DeepSeek::new( + api_key, + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + ) + }; Box::new(deepseek) } } @@ -789,29 +912,52 @@ impl LLMBuilder { #[cfg(feature = "xai")] { - let api_key = self.api_key.ok_or_else(|| { + let api_key = self.api_key.clone().ok_or_else(|| { LLMError::InvalidRequest("No API key provided for XAI".to_string()) })?; - let xai = crate::backends::xai::XAI::new( - api_key, - self.model, - self.max_tokens, - self.temperature, - self.timeout_seconds, - self.system, - self.top_p, - self.top_k, - self.embedding_encoding_format, - self.embedding_dimensions, - self.json_schema, - self.xai_search_mode, - self.xai_search_source_type, - self.xai_search_excluded_websites, - self.xai_search_max_results, - self.xai_search_from_date, - self.xai_search_to_date, - ); + let xai = if let Some(client) = self.client.clone() { + crate::backends::xai::XAI::with_client( + client, + api_key, + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + self.embedding_encoding_format.clone(), + self.embedding_dimensions, + self.json_schema.clone(), + self.xai_search_mode.clone(), + self.xai_search_source_type.clone(), + self.xai_search_excluded_websites.clone(), + self.xai_search_max_results, + self.xai_search_from_date.clone(), + self.xai_search_to_date.clone(), + ) + } else { + crate::backends::xai::XAI::new( + api_key, + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + self.embedding_encoding_format.clone(), + self.embedding_dimensions, + self.json_schema.clone(), + self.xai_search_mode.clone(), + self.xai_search_source_type.clone(), + self.xai_search_excluded_websites.clone(), + self.xai_search_max_results, + self.xai_search_from_date.clone(), + self.xai_search_to_date.clone(), + ) + }; Box::new(xai) } } @@ -823,15 +969,28 @@ impl LLMBuilder { #[cfg(feature = "phind")] { - let phind = crate::backends::phind::Phind::new( - self.model, - self.max_tokens, - self.temperature, - self.timeout_seconds, - self.system, - self.top_p, - self.top_k, - ); + let phind = if let Some(client) = self.client.clone() { + crate::backends::phind::Phind::with_client( + client, + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + ) + } else { + crate::backends::phind::Phind::new( + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + ) + }; Box::new(phind) } } @@ -843,22 +1002,38 @@ impl LLMBuilder { #[cfg(feature = "google")] { - let api_key = self.api_key.ok_or_else(|| { + let api_key = self.api_key.clone().ok_or_else(|| { LLMError::InvalidRequest("No API key provided for Google".to_string()) })?; - let google = crate::backends::google::Google::new( - api_key, - self.model, - self.max_tokens, - self.temperature, - self.timeout_seconds, - self.system, - self.top_p, - self.top_k, - self.json_schema, - tools, - ); + let google = if let Some(client) = self.client.clone() { + crate::backends::google::Google::with_client( + client, + api_key, + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + self.json_schema.clone(), + tools.clone(), + ) + } else { + crate::backends::google::Google::new( + api_key, + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + self.json_schema.clone(), + tools.clone(), + ) + }; Box::new(google) } } @@ -870,30 +1045,54 @@ impl LLMBuilder { #[cfg(feature = "groq")] { - let api_key = self.api_key.ok_or_else(|| { + let api_key = self.api_key.clone().ok_or_else(|| { LLMError::InvalidRequest("No API key provided for Groq".to_string()) })?; - let groq = crate::backends::groq::Groq::with_config( - api_key, - self.base_url, - self.model, - self.max_tokens, - self.temperature, - self.timeout_seconds, - self.system, - self.top_p, - self.top_k, - self.tools, - self.tool_choice, - self.extra_body, - None, // embedding_encoding_format - None, // embedding_dimensions - None, // reasoning_effort - self.json_schema, - self.enable_parallel_tool_use, - self.normalize_response, - ); + let groq = if let Some(client) = self.client.clone() { + crate::backends::groq::Groq::with_config_and_client( + client, + api_key, + self.base_url.clone(), + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + self.tools.clone(), + self.tool_choice.clone(), + self.extra_body.clone(), + None, + None, + None, + self.json_schema.clone(), + self.enable_parallel_tool_use, + self.normalize_response, + ) + } else { + crate::backends::groq::Groq::with_config( + api_key, + self.base_url.clone(), + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + self.tools.clone(), + self.tool_choice.clone(), + self.extra_body.clone(), + None, + None, + None, + self.json_schema.clone(), + self.enable_parallel_tool_use, + self.normalize_response, + ) + }; Box::new(groq) } } @@ -905,30 +1104,54 @@ impl LLMBuilder { #[cfg(feature = "openrouter")] { - let api_key = self.api_key.ok_or_else(|| { + let api_key = self.api_key.clone().ok_or_else(|| { LLMError::InvalidRequest("No API key provided for OpenRouter".to_string()) })?; - let openrouter = crate::backends::openrouter::OpenRouter::with_config( - api_key, - self.base_url, - self.model, - self.max_tokens, - self.temperature, - self.timeout_seconds, - self.system, - self.top_p, - self.top_k, - self.tools, - self.tool_choice, - self.extra_body, - None, // embedding_encoding_format - None, // embedding_dimensions - None, // reasoning_effort - self.json_schema, - self.enable_parallel_tool_use, - self.normalize_response, - ); + let openrouter = if let Some(client) = self.client.clone() { + crate::backends::openrouter::OpenRouter::with_config_and_client( + client, + api_key, + self.base_url.clone(), + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + self.tools.clone(), + self.tool_choice.clone(), + self.extra_body.clone(), + None, + None, + None, + self.json_schema.clone(), + self.enable_parallel_tool_use, + self.normalize_response, + ) + } else { + crate::backends::openrouter::OpenRouter::with_config( + api_key, + self.base_url.clone(), + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + self.tools.clone(), + self.tool_choice.clone(), + self.extra_body.clone(), + None, + None, + None, + self.json_schema.clone(), + self.enable_parallel_tool_use, + self.normalize_response, + ) + }; Box::new(openrouter) } } @@ -940,30 +1163,53 @@ impl LLMBuilder { #[cfg(feature = "cohere")] { - let api_key = self.api_key.ok_or_else(|| { + let api_key = self.api_key.clone().ok_or_else(|| { LLMError::InvalidRequest("No API key provided for Cohere".to_string()) })?; - let cohere = crate::backends::cohere::Cohere::new( - api_key, - self.base_url, - self.model, - self.max_tokens, - self.temperature, - self.timeout_seconds, - self.system, - self.top_p, - self.top_k, - tools, - self.tool_choice, - self.reasoning_effort, - self.json_schema, - None, - self.extra_body, - self.enable_parallel_tool_use, - self.normalize_response, - self.embedding_encoding_format, - self.embedding_dimensions, - ); + let cohere = if let Some(client) = self.client.clone() { + crate::backends::cohere::Cohere::with_config_and_client( + client, + api_key, + self.base_url.clone(), + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + tools.clone(), + self.tool_choice.clone(), + self.extra_body.clone(), + self.embedding_encoding_format.clone(), + self.embedding_dimensions, + self.reasoning_effort.clone(), + self.json_schema.clone(), + self.enable_parallel_tool_use, + self.normalize_response, + ) + } else { + crate::backends::cohere::Cohere::with_config( + api_key, + self.base_url.clone(), + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + tools.clone(), + self.tool_choice.clone(), + self.extra_body.clone(), + self.embedding_encoding_format.clone(), + self.embedding_dimensions, + self.reasoning_effort.clone(), + self.json_schema.clone(), + self.enable_parallel_tool_use, + self.normalize_response, + ) + }; Box::new(cohere) } } @@ -975,32 +1221,56 @@ impl LLMBuilder { #[cfg(feature = "huggingface")] { - let api_key = self.api_key.ok_or_else(|| { + let api_key = self.api_key.clone().ok_or_else(|| { LLMError::InvalidRequest( "No API key provided for HuggingFace Inference Providers".to_string(), ) })?; - let llm = crate::backends::huggingface::HuggingFace::with_config( - api_key, - self.base_url, - self.model, - self.max_tokens, - self.temperature, - self.timeout_seconds, - self.system, - self.top_p, - self.top_k, - self.tools, - self.tool_choice, - self.extra_body, - None, // embedding_encoding_format - None, // embedding_dimensions - None, // reasoning_effort - self.json_schema, - self.enable_parallel_tool_use, - self.normalize_response, - ); + let llm = if let Some(client) = self.client.clone() { + crate::backends::huggingface::HuggingFace::with_config_and_client( + client, + api_key, + self.base_url.clone(), + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + self.tools.clone(), + self.tool_choice.clone(), + self.extra_body.clone(), + None, + None, + None, + self.json_schema.clone(), + self.enable_parallel_tool_use, + self.normalize_response, + ) + } else { + crate::backends::huggingface::HuggingFace::with_config( + api_key, + self.base_url.clone(), + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + self.tools.clone(), + self.tool_choice.clone(), + self.extra_body.clone(), + None, + None, + None, + self.json_schema.clone(), + self.enable_parallel_tool_use, + self.normalize_response, + ) + }; Box::new(llm) } } @@ -1011,29 +1281,53 @@ impl LLMBuilder { )); #[cfg(feature = "mistral")] { - let api_key = self.api_key.ok_or_else(|| { + let api_key = self.api_key.clone().ok_or_else(|| { LLMError::InvalidRequest("No API key provided for Mistral".to_string()) })?; - let mistral = crate::backends::mistral::Mistral::with_config( - api_key, - self.base_url, - self.model, - self.max_tokens, - self.temperature, - self.timeout_seconds, - self.system, - self.top_p, - self.top_k, - tools, - tool_choice, - self.extra_body, - self.embedding_encoding_format, - self.embedding_dimensions, - self.reasoning_effort, - self.json_schema, - self.enable_parallel_tool_use, - self.normalize_response, - ); + let mistral = if let Some(client) = self.client.clone() { + crate::backends::mistral::Mistral::with_config_and_client( + client, + api_key, + self.base_url.clone(), + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + tools.clone(), + tool_choice.clone(), + self.extra_body.clone(), + self.embedding_encoding_format.clone(), + self.embedding_dimensions, + self.reasoning_effort.clone(), + self.json_schema.clone(), + self.enable_parallel_tool_use, + self.normalize_response, + ) + } else { + crate::backends::mistral::Mistral::with_config( + api_key, + self.base_url.clone(), + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + tools.clone(), + tool_choice.clone(), + self.extra_body.clone(), + self.embedding_encoding_format.clone(), + self.embedding_dimensions, + self.reasoning_effort.clone(), + self.json_schema.clone(), + self.enable_parallel_tool_use, + self.normalize_response, + ) + }; Box::new(mistral) } } @@ -1044,41 +1338,65 @@ impl LLMBuilder { )); #[cfg(feature = "azure_openai")] { - let endpoint = self.base_url.ok_or_else(|| { + let endpoint = self.base_url.clone().ok_or_else(|| { LLMError::InvalidRequest("No API endpoint provided for Azure OpenAI".into()) })?; - let key = self.api_key.ok_or_else(|| { + let key = self.api_key.clone().ok_or_else(|| { LLMError::InvalidRequest("No API key provided for Azure OpenAI".to_string()) })?; - let api_version = self.api_version.ok_or_else(|| { + let api_version = self.api_version.clone().ok_or_else(|| { LLMError::InvalidRequest( "No API version provided for Azure OpenAI".to_string(), ) })?; - let deployment = self.deployment_id.ok_or_else(|| { + let deployment = self.deployment_id.clone().ok_or_else(|| { LLMError::InvalidRequest( "No deployment ID provided for Azure OpenAI".into(), ) })?; - Box::new(crate::backends::azure_openai::AzureOpenAI::new( - key, - api_version, - deployment, - endpoint, - self.model, - self.max_tokens, - self.temperature, - self.timeout_seconds, - self.system, - self.top_p, - self.top_k, - self.embedding_encoding_format, - self.embedding_dimensions, - tools, - tool_choice, - self.reasoning_effort, - self.json_schema, - )) + let azure = if let Some(client) = self.client.clone() { + crate::backends::azure_openai::AzureOpenAI::with_client( + client, + key, + api_version, + deployment, + endpoint, + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + self.embedding_encoding_format.clone(), + self.embedding_dimensions, + tools.clone(), + tool_choice.clone(), + self.reasoning_effort.clone(), + self.json_schema.clone(), + ) + } else { + crate::backends::azure_openai::AzureOpenAI::new( + key, + api_version, + deployment, + endpoint, + self.model.clone(), + self.max_tokens, + self.temperature, + self.timeout_seconds, + self.system.clone(), + self.top_p, + self.top_k, + self.embedding_encoding_format.clone(), + self.embedding_dimensions, + tools.clone(), + tool_choice.clone(), + self.reasoning_effort.clone(), + self.json_schema.clone(), + ) + }; + Box::new(azure) } } }; @@ -1277,3 +1595,36 @@ impl FunctionBuilder { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_builder_client_method() { + let shared_client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(60)) + .build() + .expect("Failed to build client"); + + let builder = LLMBuilder::new() + .backend(LLMBackend::OpenAI) + .client(shared_client) + .api_key("test-key") + .model("gpt-4"); + + // Verify the client was set + assert!(builder.client.is_some()); + } + + #[test] + fn test_builder_without_client() { + let builder = LLMBuilder::new() + .backend(LLMBackend::OpenAI) + .api_key("test-key") + .model("gpt-4"); + + // Verify no client was set (will create default) + assert!(builder.client.is_none()); + } +} diff --git a/src/providers/openai_compatible.rs b/src/providers/openai_compatible.rs index 8aac07e..5546973 100644 --- a/src/providers/openai_compatible.rs +++ b/src/providers/openai_compatible.rs @@ -27,6 +27,7 @@ use std::pin::Pin; /// /// This struct provides a base implementation for any OpenAI-compatible API. /// Different providers can customize behavior by implementing the `OpenAICompatibleConfig` trait. +#[derive(Clone)] pub struct OpenAICompatibleProvider { pub api_key: String, pub base_url: Url, @@ -351,6 +352,68 @@ impl OpenAICompatibleProvider { } } + /// Creates a new provider with a pre-configured HTTP client. + /// + /// This allows sharing a single `reqwest::Client` across multiple providers, + /// enabling connection pooling and reducing resource usage. + /// + /// # Arguments + /// + /// * `client` - A pre-configured `reqwest::Client` to use for HTTP requests + /// * Other arguments are the same as `new()` + #[allow(clippy::too_many_arguments)] + pub fn with_client( + client: Client, + api_key: impl Into, + base_url: Option, + model: Option, + max_tokens: Option, + temperature: Option, + timeout_seconds: Option, + system: Option, + top_p: Option, + top_k: Option, + tools: Option>, + tool_choice: Option, + reasoning_effort: Option, + json_schema: Option, + voice: Option, + extra_body: Option, + parallel_tool_calls: Option, + normalize_response: Option, + embedding_encoding_format: Option, + embedding_dimensions: Option, + ) -> Self { + let extra_body = match extra_body { + Some(serde_json::Value::Object(map)) => map, + _ => serde_json::Map::new(), + }; + Self { + api_key: api_key.into(), + base_url: Url::parse(&format!("{}/", base_url.unwrap_or_else(|| T::DEFAULT_BASE_URL.to_owned()).trim_end_matches("/"))) + .expect("Failed to parse base URL"), + model: model.unwrap_or_else(|| T::DEFAULT_MODEL.to_string()), + max_tokens, + temperature, + system, + timeout_seconds, + top_p, + top_k, + tools, + tool_choice, + reasoning_effort, + json_schema, + voice, + extra_body, + parallel_tool_calls: parallel_tool_calls.unwrap_or(false), + normalize_response: normalize_response.unwrap_or(true), + embedding_encoding_format, + embedding_dimensions, + client, + _phantom: PhantomData, + } + } + pub fn prepare_messages(&self, messages: &[ChatMessage]) -> Vec> { let mut openai_msgs: Vec = messages .iter() @@ -1417,4 +1480,115 @@ mod tests { results[0] ); } + + /// Test config for Clone tests + #[derive(Clone)] + struct TestConfig; + + impl OpenAIProviderConfig for TestConfig { + const PROVIDER_NAME: &'static str = "Test"; + const DEFAULT_BASE_URL: &'static str = "https://api.test.com/v1/"; + const DEFAULT_MODEL: &'static str = "test-model"; + const SUPPORTS_REASONING_EFFORT: bool = false; + const SUPPORTS_STRUCTURED_OUTPUT: bool = true; + const SUPPORTS_PARALLEL_TOOL_CALLS: bool = false; + } + + #[test] + fn test_openai_compatible_provider_clone() { + let provider = OpenAICompatibleProvider::::new( + "test-api-key", + None, + Some("gpt-4".to_string()), + Some(1000), + Some(0.7), + Some(30), + Some("You are helpful.".to_string()), + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ); + + // Clone the provider + let cloned = provider.clone(); + + // Verify both have the same configuration + assert_eq!(provider.api_key, cloned.api_key); + assert_eq!(provider.model, cloned.model); + assert_eq!(provider.max_tokens, cloned.max_tokens); + assert_eq!(provider.temperature, cloned.temperature); + assert_eq!(provider.system, cloned.system); + } + + #[test] + fn test_openai_compatible_provider_with_client() { + let shared_client = Client::builder() + .timeout(std::time::Duration::from_secs(60)) + .build() + .expect("Failed to build client"); + + let provider = OpenAICompatibleProvider::::with_client( + shared_client.clone(), + "test-api-key", + None, + Some("gpt-4".to_string()), + Some(1000), + Some(0.7), + Some(30), + Some("You are helpful.".to_string()), + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ); + + // Verify configuration + assert_eq!(provider.api_key, "test-api-key"); + assert_eq!(provider.model, "gpt-4"); + assert_eq!(provider.max_tokens, Some(1000)); + + // Create another provider with the same client + let provider2 = OpenAICompatibleProvider::::with_client( + shared_client, + "test-api-key-2", + None, + Some("gpt-3.5-turbo".to_string()), + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ); + + assert_eq!(provider2.api_key, "test-api-key-2"); + assert_eq!(provider2.model, "gpt-3.5-turbo"); + } }