Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
275 changes: 221 additions & 54 deletions src/backends/anthropic.rs

Large diffs are not rendered by default.

236 changes: 191 additions & 45 deletions src/backends/azure_openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
//!
//! This module provides integration with Azure OpenAI's GPT models through their API.

use std::sync::Arc;

#[cfg(feature = "azure_openai")]
use crate::{
builder::LLMBackend,
Expand All @@ -24,29 +26,54 @@ use either::*;
use reqwest::{Client, Url};
use serde::{Deserialize, Serialize};

/// Client for interacting with Azure OpenAI's API.
///
/// Provides methods for chat and completion requests using Azure OpenAI's models.
pub struct AzureOpenAI {
/// Configuration for the Azure OpenAI client.
#[derive(Debug)]
pub struct AzureOpenAIConfig {
/// API key for authentication.
pub api_key: String,
/// API version string.
pub api_version: String,
/// Base URL for API requests.
pub base_url: Url,
/// Model identifier.
pub model: String,
/// Maximum tokens to generate in responses.
pub max_tokens: Option<u32>,
/// Sampling temperature for response randomness.
pub temperature: Option<f32>,
/// System prompt to guide model behavior.
pub system: Option<String>,
/// Request timeout in seconds.
pub timeout_seconds: Option<u64>,
/// Top-p (nucleus) sampling parameter.
pub top_p: Option<f32>,
/// Top-k sampling parameter.
pub top_k: Option<u32>,
/// Available tools for the model to use.
pub tools: Option<Vec<Tool>>,
/// Tool choice configuration.
pub tool_choice: Option<ToolChoice>,
/// Embedding parameters
/// Encoding format for embeddings.
pub embedding_encoding_format: Option<String>,
/// Dimensions for embeddings.
pub embedding_dimensions: Option<u32>,
/// Reasoning effort level.
pub reasoning_effort: Option<String>,
/// JSON schema for structured output
/// JSON schema for structured output.
pub json_schema: Option<StructuredOutputFormat>,
client: Client,
}

/// Client for interacting with Azure OpenAI's API.
///
/// Bundles the request configuration with the HTTP client used to send
/// chat, embedding and model-listing requests.
///
/// The configuration is held behind an [`Arc`], so cloning the client shares
/// the same [`AzureOpenAIConfig`] instead of copying its fields.
#[derive(Debug, Clone)]
pub struct AzureOpenAI {
/// Shared, reference-counted configuration (API key, base URL, model, sampling options, ...).
pub config: Arc<AzureOpenAIConfig>,
/// HTTP client used to send every request.
pub client: Client,
}

/// Individual message in an OpenAI chat conversation.
Expand Down Expand Up @@ -356,31 +383,144 @@ impl AzureOpenAI {
if let Some(sec) = timeout_seconds {
builder = builder.timeout(std::time::Duration::from_secs(sec));
}

let endpoint = endpoint.into();
let deployment_id = deployment_id.into();

Self {
api_key: api_key.into(),
api_version: api_version.into(),
base_url: Url::parse(&format!("{endpoint}/openai/deployments/{deployment_id}/"))
.expect("Failed to parse base Url"),
model: model.unwrap_or("gpt-3.5-turbo".to_string()),
Self::with_client(
builder.build().expect("Failed to build reqwest Client"),
api_key,
api_version,
deployment_id,
endpoint,
model,
max_tokens,
temperature,
system,
timeout_seconds,
system,
top_p,
top_k,
tools,
tool_choice,
embedding_encoding_format,
embedding_dimensions,
client: builder.build().expect("Failed to build reqwest Client"),
tools,
tool_choice,
reasoning_effort,
json_schema,
)
}

/// Creates a new Azure OpenAI client that sends requests through a caller-supplied HTTP client.
///
/// The base URL is assembled as `{endpoint}/openai/deployments/{deployment_id}/`.
/// Trailing `/` characters on `endpoint` are dropped first, so an endpoint written
/// with a trailing slash does not produce an empty `//` path segment.
///
/// When `model` is `None`, `"gpt-3.5-turbo"` is used.
///
/// `timeout_seconds` is only recorded in the configuration here; `client` is stored
/// exactly as given. All remaining arguments are stored unchanged in the shared
/// [`AzureOpenAIConfig`].
///
/// # Panics
///
/// Panics if the assembled base URL cannot be parsed.
#[allow(clippy::too_many_arguments)]
pub fn with_client(
    client: Client,
    api_key: impl Into<String>,
    api_version: impl Into<String>,
    deployment_id: impl Into<String>,
    endpoint: impl Into<String>,
    model: Option<String>,
    max_tokens: Option<u32>,
    temperature: Option<f32>,
    timeout_seconds: Option<u64>,
    system: Option<String>,
    top_p: Option<f32>,
    top_k: Option<u32>,
    embedding_encoding_format: Option<String>,
    embedding_dimensions: Option<u32>,
    tools: Option<Vec<Tool>>,
    tool_choice: Option<ToolChoice>,
    reasoning_effort: Option<String>,
    json_schema: Option<StructuredOutputFormat>,
) -> Self {
    let endpoint = endpoint.into();
    let deployment_id = deployment_id.into();

    // Avoid `https://host//openai/...` when the endpoint is given with a trailing slash.
    let endpoint = endpoint.trim_end_matches('/');
    let base_url = Url::parse(&format!("{endpoint}/openai/deployments/{deployment_id}/"))
        .expect("Failed to parse base Url");

    Self {
        config: Arc::new(AzureOpenAIConfig {
            api_key: api_key.into(),
            api_version: api_version.into(),
            base_url,
            // Only allocate the default model name when no model was supplied.
            model: model.unwrap_or_else(|| "gpt-3.5-turbo".to_string()),
            max_tokens,
            temperature,
            system,
            timeout_seconds,
            top_p,
            top_k,
            tools,
            tool_choice,
            embedding_encoding_format,
            embedding_dimensions,
            reasoning_effort,
            json_schema,
        }),
        client,
    }
}

/// Returns the API key used for authentication.
pub fn api_key(&self) -> &str {
&self.config.api_key
}

/// Returns the API version string sent as the `api-version` query parameter.
pub fn api_version(&self) -> &str {
&self.config.api_version
}

/// Returns the base URL that request paths are joined onto.
pub fn base_url(&self) -> &Url {
&self.config.base_url
}

/// Returns the configured model identifier.
pub fn model(&self) -> &str {
&self.config.model
}

/// Returns the maximum number of tokens to generate in responses, if configured.
pub fn max_tokens(&self) -> Option<u32> {
self.config.max_tokens
}

/// Returns the configured sampling temperature, if any.
pub fn temperature(&self) -> Option<f32> {
self.config.temperature
}

/// Returns the configured request timeout in seconds, if any.
pub fn timeout_seconds(&self) -> Option<u64> {
self.config.timeout_seconds
}

/// Returns the configured system prompt, if any, borrowed as a string slice.
pub fn system(&self) -> Option<&str> {
self.config.system.as_deref()
}

/// Returns the configured top-p (nucleus) sampling parameter, if any.
pub fn top_p(&self) -> Option<f32> {
self.config.top_p
}

/// Returns the configured top-k sampling parameter, if any.
pub fn top_k(&self) -> Option<u32> {
self.config.top_k
}

/// Returns the tools configured for the model, if any, borrowed as a slice.
pub fn tools(&self) -> Option<&[Tool]> {
self.config.tools.as_deref()
}

/// Returns the configured tool choice, if any.
pub fn tool_choice(&self) -> Option<&ToolChoice> {
self.config.tool_choice.as_ref()
}

/// Returns the configured encoding format for embeddings, if any.
pub fn embedding_encoding_format(&self) -> Option<&str> {
self.config.embedding_encoding_format.as_deref()
}

/// Returns the configured number of embedding dimensions, if any.
pub fn embedding_dimensions(&self) -> Option<u32> {
self.config.embedding_dimensions
}

/// Returns the configured reasoning effort level, if any.
pub fn reasoning_effort(&self) -> Option<&str> {
self.config.reasoning_effort.as_deref()
}

/// Returns the JSON schema configured for structured output, if any.
pub fn json_schema(&self) -> Option<&StructuredOutputFormat> {
self.config.json_schema.as_ref()
}

/// Returns the HTTP client used for making requests.
pub fn client(&self) -> &Client {
&self.client
}
}

#[async_trait]
Expand All @@ -399,7 +539,7 @@ impl ChatProvider for AzureOpenAI {
messages: &[ChatMessage],
tools: Option<&[Tool]>,
) -> Result<Box<dyn ChatResponse>, LLMError> {
if self.api_key.is_empty() {
if self.config.api_key.is_empty() {
return Err(LLMError::AuthError(
"Missing Azure OpenAI API key".to_string(),
));
Expand All @@ -425,7 +565,7 @@ impl ChatProvider for AzureOpenAI {
}
}

if let Some(system) = &self.system {
if let Some(system) = &self.config.system {
openai_msgs.insert(
0,
AzureOpenAIChatMessage {
Expand All @@ -445,26 +585,28 @@ impl ChatProvider for AzureOpenAI {

// Build the response format object
let response_format: Option<OpenAIResponseFormat> =
self.json_schema.clone().map(|s| s.into());
self.config.json_schema.clone().map(|s| s.into());

let request_tools = tools.map(|t| t.to_vec()).or_else(|| self.tools.clone());
let request_tools = tools
.map(|t| t.to_vec())
.or_else(|| self.config.tools.clone());
let request_tool_choice = if request_tools.is_some() {
self.tool_choice.clone()
self.config.tool_choice.clone()
} else {
None
};

let body = AzureOpenAIChatRequest {
model: &self.model,
model: &self.config.model,
messages: openai_msgs,
max_tokens: self.max_tokens,
temperature: self.temperature,
max_tokens: self.config.max_tokens,
temperature: self.config.temperature,
stream: false,
top_p: self.top_p,
top_k: self.top_k,
top_p: self.config.top_p,
top_k: self.config.top_k,
tools: request_tools,
tool_choice: request_tool_choice,
reasoning_effort: self.reasoning_effort.clone(),
reasoning_effort: self.config.reasoning_effort.clone(),
response_format,
};

Expand All @@ -475,20 +617,21 @@ impl ChatProvider for AzureOpenAI {
}

let mut url = self
.config
.base_url
.join("chat/completions")
.map_err(|e| LLMError::HttpError(e.to_string()))?;

url.query_pairs_mut()
.append_pair("api-version", &self.api_version);
.append_pair("api-version", &self.config.api_version);

let mut request = self
.client
.post(url)
.header("api-key", &self.api_key)
.header("api-key", &self.config.api_key)
.json(&body);

if let Some(timeout) = self.timeout_seconds {
if let Some(timeout) = self.config.timeout_seconds {
request = request.timeout(std::time::Duration::from_secs(timeout));
}

Expand Down Expand Up @@ -542,34 +685,36 @@ impl CompletionProvider for AzureOpenAI {
#[async_trait]
impl EmbeddingProvider for AzureOpenAI {
async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
if self.api_key.is_empty() {
if self.config.api_key.is_empty() {
return Err(LLMError::AuthError("Missing OpenAI API key".into()));
}

let emb_format = self
.config
.embedding_encoding_format
.clone()
.unwrap_or_else(|| "float".to_string());

let body = OpenAIEmbeddingRequest {
model: self.model.clone(),
model: self.config.model.clone(),
input,
encoding_format: Some(emb_format),
dimensions: self.embedding_dimensions,
dimensions: self.config.embedding_dimensions,
};

let mut url = self
.config
.base_url
.join("embeddings")
.map_err(|e| LLMError::HttpError(e.to_string()))?;

url.query_pairs_mut()
.append_pair("api-version", &self.api_version);
.append_pair("api-version", &self.config.api_version);

let resp = self
.client
.post(url)
.header("api-key", &self.api_key)
.header("api-key", &self.config.api_key)
.json(&body)
.send()
.await?
Expand All @@ -584,7 +729,7 @@ impl EmbeddingProvider for AzureOpenAI {

impl LLMProvider for AzureOpenAI {
fn tools(&self) -> Option<&[Tool]> {
self.tools.as_deref()
self.config.tools.as_deref()
}
}

Expand Down Expand Up @@ -612,23 +757,24 @@ impl ModelsProvider for AzureOpenAI {
&self,
_request: Option<&ModelListRequest>,
) -> Result<Box<dyn ModelListResponse>, LLMError> {
if self.api_key.is_empty() {
if self.config.api_key.is_empty() {
return Err(LLMError::AuthError(
"Missing Azure OpenAI API key".to_string(),
));
}

let mut url = self
.config
.base_url
.join("models")
.map_err(|e| LLMError::HttpError(e.to_string()))?;

url.query_pairs_mut()
.append_pair("api-version", &self.api_version);
.append_pair("api-version", &self.config.api_version);

let mut request = self.client.get(url).header("api-key", &self.api_key);
let mut request = self.client.get(url).header("api-key", &self.config.api_key);

if let Some(timeout) = self.timeout_seconds {
if let Some(timeout) = self.config.timeout_seconds {
request = request.timeout(std::time::Duration::from_secs(timeout));
}

Expand Down
Loading