Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/artificial-core/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "artificial-core"
version = "0.4.2"
version = "0.5.0"
edition = "2021"
description = "Provider-agnostic core traits, generic client and error types for the Artificial prompt-engineering SDK"
license = "MIT"
Expand Down
2 changes: 2 additions & 0 deletions crates/artificial-core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ impl<B: ChatCompletionProvider> ChatCompletionProvider for ArtificialClient<B> {
}

impl<B: StreamingChatProvider> StreamingChatProvider for ArtificialClient<B> {
type Message = B::Message;

type Delta<'s>
= B::Delta<'s>
where
Expand Down
5 changes: 4 additions & 1 deletion crates/artificial-core/src/provider/chat_complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ pub trait ChatCompletionProvider: Send + Sync {
/// Tool-call and richer payload support can be layered on later by
/// introducing a dedicated enum – starting with plain text keeps the API
/// minimal and backend-agnostic.
pub trait StreamingChatProvider: ChatCompletionProvider {
pub trait StreamingChatProvider: Send + Sync {
/// Chat message type consumed by this backend.
type Message: Send + Sync + 'static;

/// The item type returned on the stream. For now it is plain UTF-8 text
/// chunks, but back-ends are free to wrap it in richer enums if needed.
type Delta<'s>: Stream<Item = Result<String>> + Send + 's
Expand Down
9 changes: 7 additions & 2 deletions crates/artificial-openai/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "artificial-openai"
version = "0.4.2"
version = "0.5.0"
edition = "2024"
description = "OpenAI backend adapter for the Artificial prompt-engineering SDK"
license = "MIT"
Expand All @@ -24,7 +24,12 @@ reqwest = { version = "0.12", default-features = false, features = [
"deflate",
"multipart",
] }
artificial-core = { path = "../artificial-core" , version = "0.4.2"}
artificial-core = { path = "../artificial-core" , version = "0.5.0"}
futures-util = "0.3"
async-stream = "0.3"
bytes = "1"
tracing = { version = "0.1", optional = true }

[features]
default = []
tracing = ["dep:tracing"]
24 changes: 20 additions & 4 deletions crates/artificial-openai/src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{env, sync::Arc};

use artificial_core::error::{ArtificialError, Result};

use crate::client::OpenAiClient;
use crate::client::{OpenAiClient, RetryPolicy};

/// Thin wrapper that wires the HTTP client [`OpenAiClient`] into a value that
/// implements [`artificial_core::backend::Backend`].
Expand Down Expand Up @@ -40,6 +40,7 @@ impl OpenAiAdapter {}
/// Builder that collects configuration for an [`OpenAiAdapter`].
///
/// All fields are optional until [`build`](OpenAiAdapterBuilder::build) is
/// called, which validates that an API key is present.
#[derive(Default)]
pub struct OpenAiAdapterBuilder {
// API key for OpenAI; `None` until provided (e.g. via `OPENAI_API_KEY`).
pub(crate) api_key: Option<String>,
// Optional retry policy; `None` keeps the client's built-in default.
pub(crate) retry: Option<RetryPolicy>,
}

impl OpenAiAdapterBuilder {
Expand All @@ -57,19 +58,34 @@ impl OpenAiAdapterBuilder {
/// Construct a builder whose API key is taken from the `OPENAI_API_KEY`
/// environment variable; the key stays unset when the variable is absent
/// (the error is surfaced later, by `build`).
pub fn new_from_env() -> Self {
    let key = env::var("OPENAI_API_KEY").ok();
    Self {
        api_key: key,
        retry: None,
    }
}

/// Set a retry policy for OpenAI HTTP calls.
pub fn with_retry_policy(mut self, retry: RetryPolicy) -> Self {
self.retry = Some(retry);
self
}

/// Finalise the builder and return a ready-to-use adapter.
///
/// # Errors
///
/// * [`ArtificialError::Invalid`] – if the API key is missing.
pub fn build(self) -> Result<OpenAiAdapter> {
    // An API key is mandatory; fail fast with a descriptive error.
    let api_key = self.api_key.ok_or(ArtificialError::Invalid(
        "missing env variable: `OPENAI_API_KEY`".into(),
    ))?;

    // Apply the caller-supplied retry policy if one was set; otherwise the
    // client keeps its built-in default policy.
    let client = match self.retry {
        Some(retry) => OpenAiClient::new(api_key).with_retry_policy(retry),
        None => OpenAiClient::new(api_key),
    };

    Ok(OpenAiAdapter {
        client: Arc::new(client),
    })
}
}
255 changes: 234 additions & 21 deletions crates/artificial-openai/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,122 @@ use std::time::Duration;

use crate::{
api_v1::{ChatCompletionChunkResponse, ChatCompletionRequest, ChatCompletionResponse},
error::OpenAiError,
error::{OpenAiError, OpenAiRateLimitHeaders},
};

fn parse_retry_after_seconds(headers: &reqwest::header::HeaderMap) -> Duration {
use reqwest::header::RETRY_AFTER;
if let Some(val) = headers.get(RETRY_AFTER).and_then(|hv| hv.to_str().ok())
&& let Ok(secs) = val.trim().parse::<u64>()
{
return Duration::from_secs(secs);
}
Duration::from_secs(0)
}

/// Read header `name` and parse it as a `u32`; `None` on absence,
/// non-ASCII bytes, or a parse failure.
fn header_u32(headers: &reqwest::header::HeaderMap, name: &str) -> Option<u32> {
    let raw = headers.get(name)?.to_str().ok()?;
    raw.parse::<u32>().ok()
}

/// Read header `name` as an owned `String`; `None` when the header is
/// missing or its value is not valid ASCII/UTF-8.
fn header_string(headers: &reqwest::header::HeaderMap, name: &str) -> Option<String> {
    let raw = headers.get(name)?.to_str().ok()?;
    Some(raw.to_owned())
}

/// Collect OpenAI rate-limit metadata from response headers.
///
/// Returns `(retry_after, reset_at, headers)` where `retry_after` is the
/// parsed `Retry-After` duration (only when strictly positive), `reset_at`
/// is the request-bucket reset string with the token-bucket string as a
/// fallback, and `headers` is the full set of `x-ratelimit-*` values.
fn extract_rate_limit_info(
    headers: &reqwest::header::HeaderMap,
) -> (Option<Duration>, Option<String>, OpenAiRateLimitHeaders) {
    // Treat a zero/absent Retry-After as "no hint".
    let parsed = parse_retry_after_seconds(headers);
    let retry_after = (parsed.as_secs() > 0).then_some(parsed);

    let info = OpenAiRateLimitHeaders {
        limit_requests: header_u32(headers, "x-ratelimit-limit-requests"),
        remaining_requests: header_u32(headers, "x-ratelimit-remaining-requests"),
        reset_requests: header_string(headers, "x-ratelimit-reset-requests"),
        limit_tokens: header_u32(headers, "x-ratelimit-limit-tokens"),
        remaining_tokens: header_u32(headers, "x-ratelimit-remaining-tokens"),
        reset_tokens: header_string(headers, "x-ratelimit-reset-tokens"),
    };

    // Prefer request reset, fall back to token reset.
    let reset_at = info
        .reset_requests
        .clone()
        .or_else(|| info.reset_tokens.clone());

    (retry_after, reset_at, info)
}
/// Emit a tracing event describing the current rate-limit headroom.
///
/// Logs at `warn` when either the request or token bucket looks nearly
/// exhausted, otherwise at `debug`. Missing headers are treated as
/// "unknown limit" / "plenty remaining".
#[cfg(feature = "tracing")]
fn log_rate_limit_tight(headers: &reqwest::header::HeaderMap, context: &str) {
    let remaining_requests =
        header_u32(headers, "x-ratelimit-remaining-requests").unwrap_or(u32::MAX);
    let remaining_tokens =
        header_u32(headers, "x-ratelimit-remaining-tokens").unwrap_or(u32::MAX);
    let limit_requests = header_u32(headers, "x-ratelimit-limit-requests").unwrap_or(0);
    let limit_tokens = header_u32(headers, "x-ratelimit-limit-tokens").unwrap_or(0);

    // Heuristic: a bucket is "tight" when almost exhausted in absolute
    // terms, or when 5 % or less of the advertised limit remains.
    let low_fraction =
        |remaining: u32, limit: u32| limit > 0 && remaining as f32 / limit as f32 <= 0.05;
    let tight = remaining_requests <= 2
        || low_fraction(remaining_requests, limit_requests)
        || remaining_tokens <= 128
        || low_fraction(remaining_tokens, limit_tokens);

    if tight {
        tracing::warn!(
            context,
            remaining_requests,
            limit_requests,
            remaining_tokens,
            limit_tokens,
            "rate limit headroom is tight"
        );
    } else {
        tracing::debug!(
            context,
            remaining_requests,
            limit_requests,
            remaining_tokens,
            limit_tokens,
            "rate limit status"
        );
    }
}

/// Controls retry behaviour for OpenAI HTTP calls.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
/// Maximum number of retries after the initial attempt.
pub max_retries: u32,
/// Delay before the first retry; doubles each attempt (see `backoff_for`).
pub base_delay: Duration,
/// Upper bound on any single backoff delay.
pub max_delay: Duration,
/// When `true`, a server-sent `Retry-After` header may extend the delay.
pub respect_retry_after: bool,
}

// Conservative defaults: up to 3 retries, 500 ms base delay with
// exponential growth capped at 30 s, honouring `Retry-After` headers.
impl Default for RetryPolicy {
fn default() -> Self {
Self {
max_retries: 3,
base_delay: Duration::from_millis(500),
max_delay: Duration::from_secs(30),
respect_retry_after: true,
}
}
}

impl RetryPolicy {
    /// Exponential backoff for the zero-based `attempt`, capped at
    /// `max_delay`. The shift exponent is clamped to 10 (multiplier of at
    /// most 1024) and the multiplication saturates, so it cannot overflow.
    fn backoff_for(&self, attempt: u32) -> Duration {
        let exponent = attempt.min(10);
        let delay = self.base_delay.saturating_mul(1u32 << exponent);
        delay.min(self.max_delay)
    }
}

const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";

/// Minimal HTTP client for OpenAI’s *chat/completions* endpoint.
Expand All @@ -26,6 +139,7 @@ pub struct OpenAiClient {
api_key: String,
http: HttpClient,
base: String,
retry: RetryPolicy,
}

impl OpenAiClient {
Expand All @@ -51,6 +165,123 @@ impl OpenAiClient {
api_key: api_key.into(),
http,
base: base_url.unwrap_or_else(|| DEFAULT_BASE_URL.to_owned()),
retry: RetryPolicy::default(),
}
}

/// Allow callers to override the default retry policy.
pub fn with_retry_policy(mut self, retry: RetryPolicy) -> Self {
self.retry = retry;
self
}

// Internal: POST `request` as JSON to `url`, retrying transient failures.
//
// Retries (with exponential backoff from `self.retry`) on HTTP 429, any
// 5xx status, and transport-level errors, up to `retry.max_retries`
// additional attempts. When retries are exhausted, a 429 is surfaced as
// `OpenAiError::RateLimited` (with parsed rate-limit headers), any other
// non-success status as `OpenAiError::Api`, and a transport failure as
// `OpenAiError::Http`.
//
// NOTE(review): backoff uses `std::thread::sleep`, which blocks the
// executor thread this future runs on — deliberate per the inline comment
// (avoids a runtime dependency), but it will stall sibling tasks on the
// same worker thread during the sleep.
async fn post_json_with_retry(
&self,
url: String,
headers: HeaderMap,
request: &ChatCompletionRequest,
) -> Result<reqwest::Response, OpenAiError> {
// Zero-based attempt counter; incremented only when a retry is taken.
let mut attempt: u32 = 0;
loop {
// URL and headers are cloned because each loop iteration builds a
// fresh request.
let res = self
.http
.post(url.clone())
.headers(headers.clone())
.json(request)
.send()
.await;

match res {
Ok(resp) => {
let status = resp.status();
if status.is_success() {
#[cfg(feature = "tracing")]
{
log_rate_limit_tight(resp.headers(), "success");
}
return Ok(resp);
}

// Only 429 and server errors are considered transient.
let should_retry = status == reqwest::StatusCode::TOO_MANY_REQUESTS
|| status.is_server_error();

if should_retry && attempt < self.retry.max_retries {
let mut delay = self.retry.backoff_for(attempt);
// `unused_assignments` is allowed because `hdr_delay` is only
// read by the tracing block below when that feature is on.
#[allow(unused_assignments)]
let mut hdr_delay = Duration::from_secs(0);
if self.retry.respect_retry_after {
// Honour a server-sent Retry-After when it exceeds our own
// computed backoff.
hdr_delay = parse_retry_after_seconds(resp.headers());
if hdr_delay > delay {
delay = hdr_delay;
}
}
#[cfg(feature = "tracing")]
{
tracing::info!(
attempt = attempt,
status = %status,
backoff_ms = delay.as_millis() as u64,
retry_after_ms = hdr_delay.as_millis() as u64,
"retrying request due to transient status"
);
log_rate_limit_tight(resp.headers(), "retrying");
}
// Blocking sleep to avoid introducing a new async runtime dependency.
std::thread::sleep(delay);
attempt += 1;
continue;
} else {
// Non-retryable status, or retries exhausted: build the error.
// Headers are cloned before `text()` consumes the response.
let status = resp.status();
let headers_map = resp.headers().clone();
let body = resp.text().await.unwrap_or_default();
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
let (retry_after, reset_at, headers) =
extract_rate_limit_info(&headers_map);
#[cfg(feature = "tracing")]
{
let ra_ms = retry_after.map(|d| d.as_millis() as u64);
tracing::warn!(
status = %status,
retry_after_ms = ?ra_ms,
reset_at = ?reset_at,
"rate limited; giving up after retries"
);
}
return Err(OpenAiError::RateLimited {
status,
body,
retry_after,
reset_at,
headers,
});
} else {
return Err(OpenAiError::Api { status, body });
}
}
}
Err(err) => {
// Retry on transport errors up to max_retries.
// NOTE(review): `!err.is_status()` makes nearly every non-status
// transport error retryable, not just timeouts/connect failures —
// confirm that this breadth is intended.
if attempt < self.retry.max_retries
&& (err.is_timeout() || err.is_connect() || !err.is_status())
{
let delay = self.retry.backoff_for(attempt);
#[cfg(feature = "tracing")]
{
tracing::info!(
attempt = attempt,
backoff_ms = delay.as_millis() as u64,
"retrying after transport error"
);
}
std::thread::sleep(delay);
attempt += 1;
continue;
} else {
return Err(OpenAiError::Http(err));
}
}
}
}
}

Expand All @@ -68,19 +299,7 @@ impl OpenAiClient {
);

let url = format!("{}/chat/completions", self.base);
let resp = self
.http
.post(url)
.headers(headers)
.json(&request)
.send()
.await?;

if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(OpenAiError::Api { status, body });
}
let resp = self.post_json_with_retry(url, headers, &request).await?;

let bytes = resp.bytes().await?;
let parsed: ChatCompletionResponse = serde_json::from_slice(&bytes)?;
Expand Down Expand Up @@ -110,13 +329,7 @@ impl OpenAiClient {

// 3) async stream wrapper
try_stream! {
let resp = self.http.post(url).headers(headers).json(&request).send().await?;

if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(OpenAiError::Api { status, body })?;
}
let resp = self.post_json_with_retry(url, headers, &request).await?;

let mut bytes_stream = resp.bytes_stream();
let mut buf = Vec::new();
Expand Down
Loading
Loading