diff --git a/CHANGELOG.md b/CHANGELOG.md index ebf3c8d..e453f73 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,15 @@ All notable changes to this project will be documented in this file. ## [0.10.9] - 2025-10-26 +### Fixed +- Tightened Turnkey rate-limit heuristics to require explicit phrases while + preserving stack-backed searches for short ASCII patterns, preventing matches + on unrelated words such as "corporate". + +### Tests +- Added regression coverage to ensure corporate network outages and operational + failure rates classify as network/service issues rather than rate limits. + ### Changed - Raised the documented MSRV to Rust 1.90 to match the `rust-version` requirement. diff --git a/src/turnkey/classifier.rs b/src/turnkey/classifier.rs index 4a74c47..97e06d2 100644 --- a/src/turnkey/classifier.rs +++ b/src/turnkey/classifier.rs @@ -38,14 +38,27 @@ pub fn classify_turnkey_error(msg: &str) -> TurnkeyErrorKind { "duplicate", "unique" ]; - const RL_PATTERNS: &[&str] = &["429", "rate", "throttle"]; + const RL_PATTERNS: &[&str] = &["429", "throttle", "throttled", "too many requests"]; + const RL_BOUNDARY_PATTERNS: &[&str] = &[ + "rate limit", + "rate limited", + "rate limiting", + "rate-limit", + "rate-limited", + "rate-limiting", + "ratelimit", + "ratelimited", + "ratelimiting" + ]; const TO_PATTERNS: &[&str] = &["timeout", "timed out", "deadline exceeded"]; const AUTH_PATTERNS: &[&str] = &["401", "403", "unauthor", "forbidden"]; const NET_PATTERNS: &[&str] = &["network", "connection", "connect", "dns", "tls", "socket"]; if contains_any_nocase(msg, UNIQUE_PATTERNS) { TurnkeyErrorKind::UniqueLabel - } else if contains_any_nocase(msg, RL_PATTERNS) { + } else if contains_any_nocase(msg, RL_PATTERNS) + || contains_any_nocase_with_boundaries(msg, RL_BOUNDARY_PATTERNS) + { TurnkeyErrorKind::RateLimited } else if contains_any_nocase(msg, TO_PATTERNS) { TurnkeyErrorKind::Timeout @@ -65,42 +78,116 @@ pub fn classify_turnkey_error(msg: &str) -> TurnkeyErrorKind { /// allocate once to store their lowercased representation. #[inline] fn contains_nocase(haystack: &str, needle: &str) -> bool { - // Fast path: empty needle always matches. + contains_nocase_with(haystack, needle, |_, _, _| true) +} + +/// Check whether `haystack` contains any of the `needles` (ASCII +/// case-insensitive). +#[inline] +fn contains_any_nocase(haystack: &str, needles: &[&str]) -> bool { + needles.iter().any(|n| contains_nocase(haystack, n)) +} + +#[inline] +fn contains_nocase_with_boundaries(haystack: &str, needle: &str) -> bool { + contains_nocase_with(haystack, needle, |start, end, haystack_bytes| { + let prev_ok = if start == 0 { + true + } else { + !is_ascii_alphanumeric(haystack_bytes[start - 1]) + }; + let next_ok = if end >= haystack_bytes.len() { + true + } else { + !is_ascii_alphanumeric(haystack_bytes[end]) + }; + prev_ok && next_ok + }) +} + +#[inline] +fn contains_any_nocase_with_boundaries(haystack: &str, needles: &[&str]) -> bool { + needles + .iter() + .any(|n| contains_nocase_with_boundaries(haystack, n)) +} + +#[inline] +fn contains_nocase_with( + haystack: &str, + needle: &str, + mut boundary: impl FnMut(usize, usize, &[u8]) -> bool +) -> bool { if needle.is_empty() { return true; } let haystack_bytes = haystack.as_bytes(); let needle_bytes = needle.as_bytes(); + let lowered = LowercasedNeedle::new(needle_bytes); + let needle_lower = lowered.as_slice(); - let search = |needle_lower: &[u8]| { - haystack_bytes.windows(needle_lower.len()).any(|window| { + if needle_lower.is_empty() { + return true; + } + + haystack_bytes + .windows(needle_lower.len()) + .enumerate() + .any(|(start, window)| { window .iter() .zip(needle_lower.iter()) - .all(|(hay, lower_needle)| ascii_lower(*hay) == *lower_needle) + .all(|(hay, lower)| ascii_lower(*hay) == *lower) + && boundary(start, start + needle_lower.len(), haystack_bytes) }) - }; +} - if needle_bytes.len() <= STACK_NEEDLE_INLINE_CAP { - let mut inline = [0u8; STACK_NEEDLE_INLINE_CAP]; - for (idx, byte) in needle_bytes.iter().enumerate() { - inline[idx] = ascii_lower(*byte); +struct LowercasedNeedle { + inline: [u8; STACK_NEEDLE_INLINE_CAP], + len: usize, + heap: Option> +} + +impl LowercasedNeedle { + #[inline] + fn new(needle_bytes: &[u8]) -> Self { + if needle_bytes.len() <= STACK_NEEDLE_INLINE_CAP { + let mut inline = [0u8; STACK_NEEDLE_INLINE_CAP]; + for (idx, byte) in needle_bytes.iter().enumerate() { + inline[idx] = ascii_lower(*byte); + } + Self { + inline, + len: needle_bytes.len(), + heap: None + } + } else { + let mut heap = Vec::with_capacity(needle_bytes.len()); + for byte in needle_bytes { + heap.push(ascii_lower(*byte)); + } + Self { + inline: [0u8; STACK_NEEDLE_INLINE_CAP], + len: needle_bytes.len(), + heap: Some(heap) + } } - search(&inline[..needle_bytes.len()]) - } else { - let mut lowercased = Vec::with_capacity(needle_bytes.len()); - for byte in needle_bytes { - lowercased.push(ascii_lower(*byte)); + } + + #[inline] + fn as_slice(&self) -> &[u8] { + match &self.heap { + Some(heap) => heap.as_slice(), + None => &self.inline[..self.len] } - search(lowercased.as_slice()) } } -/// Check whether `haystack` contains any of the `needles` (ASCII -/// case-insensitive). #[inline] -fn contains_any_nocase(haystack: &str, needles: &[&str]) -> bool { - needles.iter().any(|n| contains_nocase(haystack, n)) +const fn is_ascii_alphanumeric(byte: u8) -> bool { + (byte >= b'0' && byte <= b'9') + || (byte >= b'A' && byte <= b'Z') + || (byte >= b'a' && byte <= b'z') } /// Converts ASCII letters to lowercase and leaves other bytes unchanged. @@ -128,4 +215,25 @@ pub(super) mod internal_tests { let needle = "a".repeat(128); assert!(contains_nocase(&haystack, &needle)); } + + #[test] + fn contains_nocase_with_boundaries_respects_word_edges() { + assert!(contains_nocase_with_boundaries( + "rate limited", + "rate limited" + )); + assert!(contains_nocase_with_boundaries( + "429 rate-limit reached", + "rate-limit" + )); + assert!(contains_nocase_with_boundaries( + "api ratelimited", + "ratelimited" + )); + assert!(!contains_nocase_with_boundaries("corporate policy", "rate")); + assert!(!contains_nocase_with_boundaries( + "accelerate limit", + "rate limit" + )); + } } diff --git a/src/turnkey/tests.rs b/src/turnkey/tests.rs index 64bfa70..c279b3e 100644 --- a/src/turnkey/tests.rs +++ b/src/turnkey/tests.rs @@ -49,7 +49,10 @@ fn classifier_rate_limited() { for s in [ "429 Too Many Requests", "rate limit exceeded", - "throttled by upstream" + "throttled by upstream", + "client ratelimited", + "rate-limited by upstream", + "rate limiting in effect" ] { assert!( matches!(classify_turnkey_error(s), TurnkeyErrorKind::RateLimited), @@ -89,7 +92,8 @@ fn classifier_network() { "connection reset", "DNS failure", "TLS handshake", - "socket hang up" + "socket hang up", + "Corporate network outage" ] { assert!( matches!(classify_turnkey_error(s), TurnkeyErrorKind::Network), @@ -100,10 +104,12 @@ fn classifier_network() { #[test] fn classifier_service_fallback() { - assert!(matches!( - classify_turnkey_error("unrecognized issue"), - TurnkeyErrorKind::Service - )); + for s in ["unrecognized issue", "operational failure rate"] { + assert!( + matches!(classify_turnkey_error(s), TurnkeyErrorKind::Service), + "failed on: {s}" + ); + } } #[test]