diff --git a/crates/artificial-core/Cargo.toml b/crates/artificial-core/Cargo.toml index 116a98d..5003fff 100644 --- a/crates/artificial-core/Cargo.toml +++ b/crates/artificial-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "artificial-core" -version = "0.4.2" +version = "0.5.0" edition = "2021" description = "Provider-agnostic core traits, generic client and error types for the Artificial prompt-engineering SDK" license = "MIT" diff --git a/crates/artificial-core/src/client.rs b/crates/artificial-core/src/client.rs index 655a163..14cee73 100644 --- a/crates/artificial-core/src/client.rs +++ b/crates/artificial-core/src/client.rs @@ -106,6 +106,8 @@ impl ChatCompletionProvider for ArtificialClient { } impl StreamingChatProvider for ArtificialClient { + type Message = B::Message; + type Delta<'s> = B::Delta<'s> where diff --git a/crates/artificial-core/src/provider/chat_complete.rs b/crates/artificial-core/src/provider/chat_complete.rs index 216a3c7..f48d3b2 100644 --- a/crates/artificial-core/src/provider/chat_complete.rs +++ b/crates/artificial-core/src/provider/chat_complete.rs @@ -39,7 +39,10 @@ pub trait ChatCompletionProvider: Send + Sync { /// Tool-call and richer payload support can be layered on later by /// introducing a dedicated enum – starting with plain text keeps the API /// minimal and backend-agnostic. -pub trait StreamingChatProvider: ChatCompletionProvider { +pub trait StreamingChatProvider: Send + Sync { + /// Chat message type consumed by this backend. + type Message: Send + Sync + 'static; + /// The item type returned on the stream. For now it is plain UTF-8 text /// chunks, but back-ends are free to wrap it in richer enums if needed. 
type Delta<'s>: Stream<Item = Result<String>> + Send + 's diff --git a/crates/artificial-openai/Cargo.toml b/crates/artificial-openai/Cargo.toml index 0fec55e..eb13812 100644 --- a/crates/artificial-openai/Cargo.toml +++ b/crates/artificial-openai/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "artificial-openai" -version = "0.4.2" +version = "0.5.0" edition = "2024" description = "OpenAI backend adapter for the Artificial prompt-engineering SDK" license = "MIT" @@ -24,7 +24,12 @@ reqwest = { version = "0.12", default-features = false, features = [ "deflate", "multipart", ] } -artificial-core = { path = "../artificial-core" , version = "0.4.2"} +artificial-core = { path = "../artificial-core" , version = "0.5.0"} futures-util = "0.3" async-stream = "0.3" bytes = "1" +tracing = { version = "0.1", optional = true } + +[features] +default = [] +tracing = ["dep:tracing"] diff --git a/crates/artificial-openai/src/adapter.rs b/crates/artificial-openai/src/adapter.rs index 021038b..463152d 100644 --- a/crates/artificial-openai/src/adapter.rs +++ b/crates/artificial-openai/src/adapter.rs @@ -2,7 +2,7 @@ use std::{env, sync::Arc}; use artificial_core::error::{ArtificialError, Result}; -use crate::client::OpenAiClient; +use crate::client::{OpenAiClient, RetryPolicy}; /// Thin wrapper that wires the HTTP client [`OpenAiClient`] into a value that /// implements [`artificial_core::backend::Backend`]. @@ -40,6 +40,7 @@ impl OpenAiAdapter {} #[derive(Default)] pub struct OpenAiAdapterBuilder { pub(crate) api_key: Option<String>, + pub(crate) retry: Option<RetryPolicy>, } impl OpenAiAdapterBuilder { @@ -57,19 +58,34 @@ impl OpenAiAdapterBuilder { pub fn new_from_env() -> Self { Self { api_key: env::var("OPENAI_API_KEY").ok(), + retry: None, } } + /// Set a retry policy for OpenAI HTTP calls. + pub fn with_retry_policy(mut self, retry: RetryPolicy) -> Self { + self.retry = Some(retry); + self + } + /// Finalise the builder and return a ready-to-use adapter. 
/// /// # Errors /// /// * [`ArtificialError::Invalid`] – if the API key is missing. pub fn build(self) -> Result<OpenAiAdapter> { + let api_key = self.api_key.ok_or(ArtificialError::Invalid( + "missing env variable: `OPENAI_API_KEY`".into(), + ))?; + + let client = if let Some(retry) = self.retry { + OpenAiClient::new(api_key).with_retry_policy(retry) + } else { + OpenAiClient::new(api_key) + }; + Ok(OpenAiAdapter { - client: Arc::new(OpenAiClient::new(self.api_key.ok_or( - ArtificialError::Invalid("missing env variable: `OPENAI_API_KEY`".into()), - )?)), + client: Arc::new(client), }) } } diff --git a/crates/artificial-openai/src/client.rs b/crates/artificial-openai/src/client.rs index eeb29cb..05372ac 100644 --- a/crates/artificial-openai/src/client.rs +++ b/crates/artificial-openai/src/client.rs @@ -10,9 +10,122 @@ use std::time::Duration; use crate::{ api_v1::{ChatCompletionChunkResponse, ChatCompletionRequest, ChatCompletionResponse}, - error::OpenAiError, + error::{OpenAiError, OpenAiRateLimitHeaders}, }; +fn parse_retry_after_seconds(headers: &reqwest::header::HeaderMap) -> Duration { + use reqwest::header::RETRY_AFTER; + if let Some(val) = headers.get(RETRY_AFTER).and_then(|hv| hv.to_str().ok()) + && let Ok(secs) = val.trim().parse::<u64>() + { + return Duration::from_secs(secs); + } + Duration::from_secs(0) +} + +fn header_u32(headers: &reqwest::header::HeaderMap, name: &str) -> Option<u32> { + headers + .get(name) + .and_then(|hv| hv.to_str().ok()) + .and_then(|s| s.parse::<u32>().ok()) +} + +fn header_string(headers: &reqwest::header::HeaderMap, name: &str) -> Option<String> { + headers + .get(name) + .and_then(|hv| hv.to_str().ok()) + .map(|s| s.to_string()) +} + +fn extract_rate_limit_info( + headers: &reqwest::header::HeaderMap, +) -> (Option<Duration>, Option<String>, OpenAiRateLimitHeaders) { + let retry_after = { + let d = parse_retry_after_seconds(headers); + if d.as_secs() > 0 { Some(d) } else { None } + }; + + let info = OpenAiRateLimitHeaders { + limit_requests: header_u32(headers, 
"x-ratelimit-limit-requests"), + remaining_requests: header_u32(headers, "x-ratelimit-remaining-requests"), + reset_requests: header_string(headers, "x-ratelimit-reset-requests"), + limit_tokens: header_u32(headers, "x-ratelimit-limit-tokens"), + remaining_tokens: header_u32(headers, "x-ratelimit-remaining-tokens"), + reset_tokens: header_string(headers, "x-ratelimit-reset-tokens"), + }; + + // Prefer request reset, fall back to token reset. + let reset_at = info + .reset_requests + .clone() + .or_else(|| info.reset_tokens.clone()); + + (retry_after, reset_at, info) +} +#[cfg(feature = "tracing")] +fn log_rate_limit_tight(headers: &reqwest::header::HeaderMap, context: &str) { + let rem_reqs = header_u32(headers, "x-ratelimit-remaining-requests").unwrap_or(u32::MAX); + let rem_tokens = header_u32(headers, "x-ratelimit-remaining-tokens").unwrap_or(u32::MAX); + let lim_reqs = header_u32(headers, "x-ratelimit-limit-requests").unwrap_or(0); + let lim_tokens = header_u32(headers, "x-ratelimit-limit-tokens").unwrap_or(0); + + // Heuristics: warn when headroom is tight + let tight_reqs = rem_reqs <= 2 || (lim_reqs > 0 && rem_reqs as f32 / lim_reqs as f32 <= 0.05); + let tight_tokens = + rem_tokens <= 128 || (lim_tokens > 0 && rem_tokens as f32 / lim_tokens as f32 <= 0.05); + + if tight_reqs || tight_tokens { + tracing::warn!( + context, + remaining_requests = rem_reqs, + limit_requests = lim_reqs, + remaining_tokens = rem_tokens, + limit_tokens = lim_tokens, + "rate limit headroom is tight" + ); + } else { + tracing::debug!( + context, + remaining_requests = rem_reqs, + limit_requests = lim_reqs, + remaining_tokens = rem_tokens, + limit_tokens = lim_tokens, + "rate limit status" + ); + } +} + +#[derive(Clone, Debug)] +pub struct RetryPolicy { + pub max_retries: u32, + pub base_delay: Duration, + pub max_delay: Duration, + pub respect_retry_after: bool, +} + +impl Default for RetryPolicy { + fn default() -> Self { + Self { + max_retries: 3, + base_delay: 
Duration::from_millis(500), + max_delay: Duration::from_secs(30), + respect_retry_after: true, + } + } +} + +impl RetryPolicy { + fn backoff_for(&self, attempt: u32) -> Duration { + let pow = attempt.min(10); + let backoff = self.base_delay.saturating_mul(1 << pow); + if backoff > self.max_delay { + self.max_delay + } else { + backoff + } + } +} + const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1"; /// Minimal HTTP client for OpenAI’s *chat/completions* endpoint. @@ -26,6 +139,7 @@ pub struct OpenAiClient { api_key: String, http: HttpClient, base: String, + retry: RetryPolicy, } impl OpenAiClient { @@ -51,6 +165,123 @@ impl OpenAiClient { api_key: api_key.into(), http, base: base_url.unwrap_or_else(|| DEFAULT_BASE_URL.to_owned()), + retry: RetryPolicy::default(), + } + } + + /// Allow callers to override the default retry policy. + pub fn with_retry_policy(mut self, retry: RetryPolicy) -> Self { + self.retry = retry; + self + } + + // Internal: send POST with retry/backoff handling. 
+ async fn post_json_with_retry( + &self, + url: String, + headers: HeaderMap, + request: &ChatCompletionRequest, + ) -> Result { + let mut attempt: u32 = 0; + loop { + let res = self + .http + .post(url.clone()) + .headers(headers.clone()) + .json(request) + .send() + .await; + + match res { + Ok(resp) => { + let status = resp.status(); + if status.is_success() { + #[cfg(feature = "tracing")] + { + log_rate_limit_tight(resp.headers(), "success"); + } + return Ok(resp); + } + + let should_retry = status == reqwest::StatusCode::TOO_MANY_REQUESTS + || status.is_server_error(); + + if should_retry && attempt < self.retry.max_retries { + let mut delay = self.retry.backoff_for(attempt); + #[allow(unused_assignments)] + let mut hdr_delay = Duration::from_secs(0); + if self.retry.respect_retry_after { + hdr_delay = parse_retry_after_seconds(resp.headers()); + if hdr_delay > delay { + delay = hdr_delay; + } + } + #[cfg(feature = "tracing")] + { + tracing::info!( + attempt = attempt, + status = %status, + backoff_ms = delay.as_millis() as u64, + retry_after_ms = hdr_delay.as_millis() as u64, + "retrying request due to transient status" + ); + log_rate_limit_tight(resp.headers(), "retrying"); + } + // Blocking sleep to avoid introducing a new async runtime dependency. 
+ std::thread::sleep(delay); + attempt += 1; + continue; + } else { + let status = resp.status(); + let headers_map = resp.headers().clone(); + let body = resp.text().await.unwrap_or_default(); + if status == reqwest::StatusCode::TOO_MANY_REQUESTS { + let (retry_after, reset_at, headers) = + extract_rate_limit_info(&headers_map); + #[cfg(feature = "tracing")] + { + let ra_ms = retry_after.map(|d| d.as_millis() as u64); + tracing::warn!( + status = %status, + retry_after_ms = ?ra_ms, + reset_at = ?reset_at, + "rate limited; giving up after retries" + ); + } + return Err(OpenAiError::RateLimited { + status, + body, + retry_after, + reset_at, + headers, + }); + } else { + return Err(OpenAiError::Api { status, body }); + } + } + } + Err(err) => { + // Retry on transport errors up to max_retries. + if attempt < self.retry.max_retries + && (err.is_timeout() || err.is_connect() || !err.is_status()) + { + let delay = self.retry.backoff_for(attempt); + #[cfg(feature = "tracing")] + { + tracing::info!( + attempt = attempt, + backoff_ms = delay.as_millis() as u64, + "retrying after transport error" + ); + } + std::thread::sleep(delay); + attempt += 1; + continue; + } else { + return Err(OpenAiError::Http(err)); + } + } + } } } @@ -68,19 +299,7 @@ impl OpenAiClient { ); let url = format!("{}/chat/completions", self.base); - let resp = self - .http - .post(url) - .headers(headers) - .json(&request) - .send() - .await?; - - if !resp.status().is_success() { - let status = resp.status(); - let body = resp.text().await.unwrap_or_default(); - return Err(OpenAiError::Api { status, body }); - } + let resp = self.post_json_with_retry(url, headers, &request).await?; let bytes = resp.bytes().await?; let parsed: ChatCompletionResponse = serde_json::from_slice(&bytes)?; @@ -110,13 +329,7 @@ impl OpenAiClient { // 3) async stream wrapper try_stream! 
{ - let resp = self.http.post(url).headers(headers).json(&request).send().await?; - - if !resp.status().is_success() { - let status = resp.status(); - let body = resp.text().await.unwrap_or_default(); - return Err(OpenAiError::Api { status, body })?; - } + let resp = self.post_json_with_retry(url, headers, &request).await?; let mut bytes_stream = resp.bytes_stream(); let mut buf = Vec::new(); diff --git a/crates/artificial-openai/src/error.rs b/crates/artificial-openai/src/error.rs index 8b19766..70b30dc 100644 --- a/crates/artificial-openai/src/error.rs +++ b/crates/artificial-openai/src/error.rs @@ -2,6 +2,18 @@ use std::str::Utf8Error; use artificial_core::error::ArtificialError; use reqwest::StatusCode; +use std::time::Duration; + +/// Headers conveying rate limit information returned by OpenAI. +#[derive(Debug, Clone)] +pub struct OpenAiRateLimitHeaders { + pub limit_requests: Option<u32>, + pub remaining_requests: Option<u32>, + pub reset_requests: Option<String>, + pub limit_tokens: Option<u32>, + pub remaining_tokens: Option<u32>, + pub reset_tokens: Option<String>, +} /// High-level error type covering every failure mode the client can hit. 
#[derive(Debug, thiserror::Error)] @@ -12,6 +24,15 @@ pub enum OpenAiError { #[error("couldn’t serialise body: {0}")] Serde(#[from] serde_json::Error), + #[error("rate limited (status {status}), retry_after={retry_after:?}, reset_at={reset_at:?}")] + RateLimited { + status: StatusCode, + body: String, + retry_after: Option<Duration>, + reset_at: Option<String>, + headers: OpenAiRateLimitHeaders, + }, + #[error("OpenAI returned non-success status {status}: {body}")] Api { status: StatusCode, body: String }, diff --git a/crates/artificial-openai/src/lib.rs b/crates/artificial-openai/src/lib.rs index 93bcc1e..b43782d 100644 --- a/crates/artificial-openai/src/lib.rs +++ b/crates/artificial-openai/src/lib.rs @@ -7,4 +7,5 @@ mod provider_impl_prompt; pub use adapter::{OpenAiAdapter, OpenAiAdapterBuilder}; mod api_v1; mod client; +pub use client::RetryPolicy; pub mod error; diff --git a/crates/artificial-openai/src/provider_impl_chat_stream.rs b/crates/artificial-openai/src/provider_impl_chat_stream.rs index 3f2d974..a854b08 100644 --- a/crates/artificial-openai/src/provider_impl_chat_stream.rs +++ b/crates/artificial-openai/src/provider_impl_chat_stream.rs @@ -1,6 +1,7 @@ use std::pin::Pin; use crate::OpenAiAdapter; +use crate::api_v1::ChatCompletionMessage; use crate::api_v1::ChatCompletionRequest; use crate::api_v1::FinishReason; use artificial_core::error::{ArtificialError, Result}; @@ -11,6 +12,8 @@ use futures_core::stream::Stream; use std::collections::HashMap; impl StreamingChatProvider for OpenAiAdapter { + type Message = ChatCompletionMessage; + type Delta<'s> = Pin<Box<dyn Stream<Item = Result<String>> + Send + 's>> where diff --git a/crates/artificial-prompt/Cargo.toml b/crates/artificial-prompt/Cargo.toml index ddaaa7b..3953f15 100644 --- a/crates/artificial-prompt/Cargo.toml +++ b/crates/artificial-prompt/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "artificial-prompt" -version = "0.4.2" +version = "0.5.0" edition = "2024" description = "Fluent builders and helpers for composing markdown prompt fragments" license 
= "MIT" @@ -9,4 +9,4 @@ categories = ["development-tools", "text-processing"] keywords = ["ai", "prompt", "markdown", "builder"] [dependencies] -artificial-core = { path = "../artificial-core" , version = "0.4.2"} +artificial-core = { path = "../artificial-core" , version = "0.5.0"} diff --git a/crates/artificial-types/Cargo.toml b/crates/artificial-types/Cargo.toml index bc69f14..956f1e1 100644 --- a/crates/artificial-types/Cargo.toml +++ b/crates/artificial-types/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "artificial-types" -version = "0.4.2" +version = "0.5.0" edition = "2024" description = "Reusable prompt fragments and helper types for the Artificial prompt-engineering SDK" license = "MIT" @@ -9,8 +9,8 @@ categories = ["development-tools", "text-processing"] keywords = ["ai", "prompt-fragments", "json-schema", "sdk"] [dependencies] -artificial-core = { path = "../artificial-core" , version = "0.4.2"} -artificial-prompt = { path = "../artificial-prompt" , version = "0.4.2"} +artificial-core = { path = "../artificial-core" , version = "0.5.0"} +artificial-prompt = { path = "../artificial-prompt" , version = "0.5.0"} chrono = "0.4.41" schemars.workspace = true diff --git a/crates/artificial/Cargo.toml b/crates/artificial/Cargo.toml index 7f3cd2a..80a8f84 100644 --- a/crates/artificial/Cargo.toml +++ b/crates/artificial/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "artificial" -version = "0.4.2" +version = "0.5.0" edition = "2024" description = "Typed, provider-agnostic prompt-engineering SDK for Rust" authors = ["Marc Riegel "] @@ -13,12 +13,13 @@ keywords = ["ai", "openai", "prompt-engineering", "json-schema", "typed"] [features] default = ["openai"] openai = ["dep:artificial-openai"] +tracing = ["artificial-openai/tracing"] [dependencies] -artificial-types = { path = "../artificial-types", version = "0.4.2" } -artificial-openai = { path = "../artificial-openai", optional = true, version = "0.4.2" } -artificial-core = { path = "../artificial-core", version = 
"0.4.2" } -artificial-prompt = { path = "../artificial-prompt", version = "0.4.2" } +artificial-types = { path = "../artificial-types", version = "0.5.0" } +artificial-openai = { path = "../artificial-openai", optional = true, version = "0.5.0" } +artificial-core = { path = "../artificial-core", version = "0.5.0" } +artificial-prompt = { path = "../artificial-prompt", version = "0.5.0" } [dev-dependencies] tokio = { version = "1", features = ["full"] }