Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 115 additions & 1 deletion src/backends/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use serde_json::Value;
/// Client for interacting with Anthropic's API.
///
/// Provides methods for chat and completion requests using Anthropic's models.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Anthropic {
pub api_key: String,
pub model: String,
Expand Down Expand Up @@ -506,6 +506,43 @@ impl Anthropic {
client: builder.build().expect("Failed to build reqwest Client"),
}
}

/// Builds an Anthropic client around an externally supplied `reqwest::Client`.
///
/// Sharing one HTTP client between several providers lets them all reuse the
/// same connection pool instead of each holding its own, which lowers
/// resource usage.
#[allow(clippy::too_many_arguments)]
pub fn with_client(
    client: Client,
    api_key: impl Into<String>,
    model: Option<String>,
    max_tokens: Option<u32>,
    temperature: Option<f32>,
    timeout_seconds: Option<u64>,
    system: Option<String>,
    top_p: Option<f32>,
    top_k: Option<u32>,
    tools: Option<Vec<Tool>>,
    tool_choice: Option<ToolChoice>,
    reasoning: Option<bool>,
    thinking_budget_tokens: Option<u32>,
) -> Self {
    // Resolve the allocating defaults up front so the struct literal stays flat.
    let model = model.unwrap_or_else(|| "claude-3-sonnet-20240229".to_string());
    let system = system.unwrap_or_else(|| "You are a helpful assistant.".to_string());
    Self {
        api_key: api_key.into(),
        model,
        max_tokens: max_tokens.unwrap_or(300),
        temperature: temperature.unwrap_or(0.7),
        system,
        timeout_seconds: timeout_seconds.unwrap_or(30),
        top_p,
        top_k,
        tools,
        tool_choice,
        reasoning: reasoning.unwrap_or(false),
        thinking_budget_tokens,
        client,
    }
}
}

#[async_trait]
Expand Down Expand Up @@ -1391,4 +1428,81 @@ data: {"type": "ping"}
let result = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap();
assert!(result.is_none());
}

#[test]
fn test_anthropic_clone() {
    let original = Anthropic::new(
        "test-api-key",
        Some("claude-3-sonnet".to_string()),
        Some(1000),
        Some(0.7),
        Some(30),
        Some("You are helpful.".to_string()),
        None,
        None,
        None,
        None,
        None,
        None,
    );

    let copy = original.clone();

    // A clone must carry over every piece of configuration unchanged.
    assert_eq!(original.api_key, copy.api_key);
    assert_eq!(original.model, copy.model);
    assert_eq!(original.max_tokens, copy.max_tokens);
    assert_eq!(original.temperature, copy.temperature);
    assert_eq!(original.system, copy.system);
}

#[test]
fn test_anthropic_with_client() {
    // One shared HTTP client, reused by every provider constructed below.
    let shared_client = Client::builder()
        .timeout(std::time::Duration::from_secs(60))
        .build()
        .expect("Failed to build client");

    let first = Anthropic::with_client(
        shared_client.clone(),
        "test-api-key",
        Some("claude-3-sonnet".to_string()),
        Some(1000),
        Some(0.7),
        Some(30),
        Some("You are helpful.".to_string()),
        None,
        None,
        None,
        None,
        None,
        None,
    );

    // The explicit settings must survive construction.
    assert_eq!(first.api_key, "test-api-key");
    assert_eq!(first.model, "claude-3-sonnet");
    assert_eq!(first.max_tokens, 1000);

    // A second provider can be built from the very same client.
    let second = Anthropic::with_client(
        shared_client,
        "test-api-key-2",
        Some("claude-3-haiku".to_string()),
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
    );

    assert_eq!(second.api_key, "test-api-key-2");
    assert_eq!(second.model, "claude-3-haiku");
}
}
51 changes: 51 additions & 0 deletions src/backends/azure_openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use serde::{Deserialize, Serialize};
/// Client for interacting with Azure OpenAI's API.
///
/// Provides methods for chat and completion requests using Azure OpenAI's models.
#[derive(Clone)]
pub struct AzureOpenAI {
pub api_key: String,
pub api_version: String,
Expand Down Expand Up @@ -381,6 +382,56 @@ impl AzureOpenAI {
json_schema,
}
}

/// Creates a new Azure OpenAI client with a pre-configured HTTP client.
///
/// This allows sharing a single `reqwest::Client` across multiple providers,
/// enabling connection pooling and reducing resource usage.
///
/// # Panics
///
/// Panics if `endpoint` combined with the deployment path does not form a
/// valid URL.
#[allow(clippy::too_many_arguments)]
pub fn with_client(
    client: Client,
    api_key: impl Into<String>,
    api_version: impl Into<String>,
    deployment_id: impl Into<String>,
    endpoint: impl Into<String>,
    model: Option<String>,
    max_tokens: Option<u32>,
    temperature: Option<f32>,
    timeout_seconds: Option<u64>,
    system: Option<String>,
    top_p: Option<f32>,
    top_k: Option<u32>,
    embedding_encoding_format: Option<String>,
    embedding_dimensions: Option<u32>,
    tools: Option<Vec<Tool>>,
    tool_choice: Option<ToolChoice>,
    reasoning_effort: Option<String>,
    json_schema: Option<StructuredOutputFormat>,
) -> Self {
    let endpoint = endpoint.into();
    let deployment_id = deployment_id.into();

    Self {
        api_key: api_key.into(),
        api_version: api_version.into(),
        // NOTE(review): assumes `endpoint` carries no trailing slash, else the
        // path contains `//openai/...` — confirm against the primary constructor.
        base_url: Url::parse(&format!("{endpoint}/openai/deployments/{deployment_id}/"))
            .expect("Failed to parse base Url"),
        // Lazy default avoids allocating the fallback string when a model is
        // supplied (clippy::or_fun_call).
        model: model.unwrap_or_else(|| "gpt-3.5-turbo".to_string()),
        max_tokens,
        temperature,
        system,
        timeout_seconds,
        top_p,
        top_k,
        tools,
        tool_choice,
        embedding_encoding_format,
        embedding_dimensions,
        client,
        reasoning_effort,
        json_schema,
    }
}
}

#[async_trait]
Expand Down
48 changes: 48 additions & 0 deletions src/backends/cohere.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize};

/// Cohere configuration for the generic provider
// Zero-sized marker type; its behavior comes entirely from the
// `OpenAIProviderConfig` impl below.
#[derive(Clone)]
pub struct CohereConfig;

impl OpenAIProviderConfig for CohereConfig {
Expand Down Expand Up @@ -78,6 +79,53 @@ impl Cohere {
embedding_dimensions,
)
}

/// Creates a new Cohere client with a pre-configured HTTP client.
///
/// This allows sharing a single `reqwest::Client` across multiple providers,
/// enabling connection pooling and reducing resource usage.
///
/// Thin wrapper: every argument is forwarded positionally to the generic
/// OpenAI-compatible provider, so the call below must stay in sync with
/// `OpenAICompatibleProvider::with_client`'s parameter order.
#[allow(clippy::too_many_arguments)]
pub fn with_config_and_client(
    client: reqwest::Client,
    api_key: impl Into<String>,
    base_url: Option<String>,
    model: Option<String>,
    max_tokens: Option<u32>,
    temperature: Option<f32>,
    timeout_seconds: Option<u64>,
    system: Option<String>,
    top_p: Option<f32>,
    top_k: Option<u32>,
    tools: Option<Vec<Tool>>,
    tool_choice: Option<ToolChoice>,
    extra_body: Option<serde_json::Value>,
    embedding_encoding_format: Option<String>,
    embedding_dimensions: Option<u32>,
    reasoning_effort: Option<String>,
    json_schema: Option<StructuredOutputFormat>,
    parallel_tool_calls: Option<bool>,
    normalize_response: Option<bool>,
) -> Self {
    // Caution: the inner signature's order differs from this function's
    // (e.g. `extra_body` and the embedding options sit in different slots).
    <OpenAICompatibleProvider<CohereConfig>>::with_client(
        client,
        api_key,
        base_url,
        model,
        max_tokens,
        temperature,
        timeout_seconds,
        system,
        top_p,
        top_k,
        tools,
        tool_choice,
        reasoning_effort,
        json_schema,
        None, // NOTE(review): unnamed optional slot in the generic signature — confirm its meaning
        extra_body,
        parallel_tool_calls,
        normalize_response,
        embedding_encoding_format,
        embedding_dimensions,
    )
}
}

// Cohere-specific implementations that don't fit in the generic OpenAI-compatible provider
Expand Down
25 changes: 25 additions & 0 deletions src/backends/deepseek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use serde::{Deserialize, Serialize};

use crate::ToolCall;

#[derive(Clone)]
pub struct DeepSeek {
pub api_key: String,
pub model: String,
Expand Down Expand Up @@ -105,6 +106,30 @@ impl DeepSeek {
client: builder.build().expect("Failed to build reqwest Client"),
}
}

/// Builds a DeepSeek client on top of an externally owned `reqwest::Client`.
///
/// Useful when several providers should share one HTTP client — and therefore
/// one connection pool — instead of each constructing their own.
pub fn with_client(
    client: Client,
    api_key: impl Into<String>,
    model: Option<String>,
    max_tokens: Option<u32>,
    temperature: Option<f32>,
    timeout_seconds: Option<u64>,
    system: Option<String>,
) -> Self {
    // Fall back to the chat model when the caller does not pick one.
    let model = model.unwrap_or_else(|| "deepseek-chat".to_string());
    Self {
        api_key: api_key.into(),
        model,
        max_tokens,
        temperature,
        system,
        timeout_seconds,
        client,
    }
}
}

#[async_trait]
Expand Down
29 changes: 28 additions & 1 deletion src/backends/elevenlabs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::time::Duration;
///
/// This struct provides functionality for speech-to-text transcription using the ElevenLabs API.
/// It implements various LLM provider traits but only supports speech-to-text functionality.
#[derive(Clone)]
pub struct ElevenLabs {
/// API key for ElevenLabs authentication
api_key: String,
Expand Down Expand Up @@ -91,13 +92,39 @@ impl ElevenLabs {
base_url: String,
timeout_seconds: Option<u64>,
voice: Option<String>,
) -> Self {
let mut builder = Client::builder();
if let Some(sec) = timeout_seconds {
builder = builder.timeout(Duration::from_secs(sec));
}
Self {
api_key,
model_id,
base_url,
timeout_seconds,
client: builder.build().expect("Failed to build reqwest Client"),
voice,
}
}

/// Creates a new ElevenLabs client with a pre-configured HTTP client.
///
/// This allows sharing a single `reqwest::Client` across multiple providers,
/// enabling connection pooling and reducing resource usage.
///
/// `timeout_seconds` is only stored on the struct here; the timeout actually
/// in effect is whatever the supplied `client` was built with.
pub fn with_client(
    client: Client,
    api_key: String,
    model_id: String,
    base_url: String,
    timeout_seconds: Option<u64>,
    voice: Option<String>,
) -> Self {
    Self {
        api_key,
        model_id,
        base_url,
        timeout_seconds,
        // Use the caller's shared client; constructing a fresh `Client::new()`
        // here would defeat the purpose of this constructor.
        client,
        voice,
    }
}
Expand Down
34 changes: 34 additions & 0 deletions src/backends/google.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ use serde_json::Value;
///
/// This struct holds the configuration and state needed to make requests to the Gemini API.
/// It implements the [`ChatProvider`], [`CompletionProvider`], and [`EmbeddingProvider`] traits.
#[derive(Clone)]
pub struct Google {
/// API key for authentication with Google's API
pub api_key: String,
Expand Down Expand Up @@ -514,6 +515,39 @@ impl Google {
client: builder.build().expect("Failed to build reqwest Client"),
}
}

/// Constructs a Google Gemini client that reuses an existing `reqwest::Client`.
///
/// Handy when several providers should share one HTTP client — and therefore
/// one connection pool — rather than each building their own.
#[allow(clippy::too_many_arguments)]
pub fn with_client(
    client: Client,
    api_key: impl Into<String>,
    model: Option<String>,
    max_tokens: Option<u32>,
    temperature: Option<f32>,
    timeout_seconds: Option<u64>,
    system: Option<String>,
    top_p: Option<f32>,
    top_k: Option<u32>,
    json_schema: Option<StructuredOutputFormat>,
    tools: Option<Vec<Tool>>,
) -> Self {
    // Fall back to a default model when none is supplied.
    let model = model.unwrap_or_else(|| String::from("gemini-1.5-flash"));
    Self {
        client,
        api_key: api_key.into(),
        model,
        max_tokens,
        temperature,
        system,
        timeout_seconds,
        top_p,
        top_k,
        json_schema,
        tools,
    }
}
}

#[async_trait]
Expand Down
Loading