Skip to content

Commit de6da50

Browse files
committed
fix(agents): validate UTF-8 boundaries in mention parsing
1 parent d201070 commit de6da50

File tree

1 file changed

+142
-6
lines changed

1 file changed

+142
-6
lines changed

src/cortex-agents/src/mention.rs

Lines changed: 142 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,46 @@
1717
use regex::Regex;
1818
use std::sync::LazyLock;
1919

20+
/// Returns the prefix of `text` that ends at byte offset `pos`, never splitting a character.
///
/// * If `pos` is at or beyond `text.len()`, the whole string is returned.
/// * If `pos` lies on a UTF-8 character boundary, the result is exactly `&text[..pos]`.
/// * If `pos` falls inside a multi-byte character, the end of the slice is moved
///   backwards to the start of that character, so the partially covered character
///   is left out of the result.
///
/// Unlike slicing with `&text[..pos]` directly, this never panics.
fn safe_slice_up_to(text: &str, pos: usize) -> &str {
    if pos >= text.len() {
        return text;
    }
    // Walk backwards from `pos` to the closest character boundary. When `pos` is
    // already a boundary the search stops immediately, so no separate fast path is
    // needed. Offset 0 is always a boundary, which means the search cannot come up
    // empty; `unwrap_or(0)` only exists to keep this function free of panic paths.
    let end = (0..=pos)
        .rev()
        .find(|&idx| text.is_char_boundary(idx))
        .unwrap_or(0);
    &text[..end]
}
39+
40+
/// Returns the suffix of `text` that starts at byte offset `pos`, never splitting a character.
///
/// * If `pos` is at or beyond `text.len()`, the empty string is returned.
/// * If `pos` lies on a UTF-8 character boundary, the result is exactly `&text[pos..]`.
/// * If `pos` falls inside a multi-byte character, the start of the slice is moved
///   forwards to the next character boundary, so the partially covered character
///   is left out of the result.
///
/// Unlike slicing with `&text[pos..]` directly, this never panics.
fn safe_slice_from(text: &str, pos: usize) -> &str {
    if pos >= text.len() {
        return "";
    }
    // Walk forwards from `pos` to the closest character boundary. When `pos` is
    // already a boundary the search stops immediately, so no separate fast path is
    // needed. If no boundary exists before the end of the string, fall back to
    // `text.len()`, which yields an empty suffix.
    let start = (pos..text.len())
        .find(|&idx| text.is_char_boundary(idx))
        .unwrap_or(text.len());
    &text[start..]
}
59+
2060
/// A parsed agent mention from user input.
2161
#[derive(Debug, Clone, PartialEq, Eq)]
2262
pub struct AgentMention {
@@ -108,10 +148,10 @@ pub fn extract_mention_and_text(
108148
) -> Option<(AgentMention, String)> {
109149
let mention = find_first_valid_mention(text, valid_agents)?;
110150

111-
// Remove the mention from text
151+
// Remove the mention from text, using safe slicing for UTF-8 boundaries
112152
let mut remaining = String::with_capacity(text.len());
113-
remaining.push_str(&text[..mention.start]);
114-
remaining.push_str(&text[mention.end..]);
153+
remaining.push_str(safe_slice_up_to(text, mention.start));
154+
remaining.push_str(safe_slice_from(text, mention.end));
115155

116156
// Trim and normalize whitespace
117157
let remaining = remaining.trim().to_string();
@@ -123,7 +163,8 @@ pub fn extract_mention_and_text(
123163
pub fn starts_with_mention(text: &str, valid_agents: &[&str]) -> bool {
124164
let text = text.trim();
125165
if let Some(mention) = find_first_valid_mention(text, valid_agents) {
126-
mention.start == 0 || text[..mention.start].trim().is_empty()
166+
// Use safe slicing to handle UTF-8 boundaries
167+
mention.start == 0 || safe_slice_up_to(text, mention.start).trim().is_empty()
127168
} else {
128169
false
129170
}
@@ -196,8 +237,8 @@ pub fn parse_message_for_agent(text: &str, valid_agents: &[&str]) -> ParsedAgent
196237

197238
// Check if message starts with @agent
198239
if let Some((mention, remaining)) = extract_mention_and_text(text, valid_agents) {
199-
// Only trigger if mention is at the start
200-
if mention.start == 0 || text[..mention.start].trim().is_empty() {
240+
// Only trigger if mention is at the start, using safe slicing for UTF-8 boundaries
241+
if mention.start == 0 || safe_slice_up_to(text, mention.start).trim().is_empty() {
201242
return ParsedAgentMessage::for_agent(mention.agent_name, remaining, text.to_string());
202243
}
203244
}
@@ -318,4 +359,99 @@ mod tests {
318359
assert_eq!(mentions[0].agent_name, "my-agent");
319360
assert_eq!(mentions[1].agent_name, "my_agent");
320361
}
362+
363+
// UTF-8 boundary safety tests
364+
#[test]
fn test_safe_slice_up_to_ascii() {
    let text = "hello world";

    // In pure ASCII every byte offset is a character boundary, so the result is
    // simply the plain prefix.
    assert_eq!(safe_slice_up_to(text, 5), "hello");
    assert_eq!(safe_slice_up_to(text, 0), "");

    // Exactly at the end and far past the end both yield the whole string.
    assert_eq!(safe_slice_up_to(text, text.len()), "hello world");
    assert_eq!(safe_slice_up_to(text, 100), "hello world");

    // Empty input must not panic, whatever offset is requested.
    assert_eq!(safe_slice_up_to("", 0), "");
    assert_eq!(safe_slice_up_to("", 7), "");
}
371+
372+
#[test]
fn test_safe_slice_up_to_multibyte() {
    // "こんにちは" - each character is 3 bytes, 15 bytes in total.
    let text = "こんにちは";

    // Offsets that sit exactly on a character boundary.
    assert_eq!(safe_slice_up_to(text, 0), "");
    assert_eq!(safe_slice_up_to(text, 3), "こ");
    assert_eq!(safe_slice_up_to(text, 6), "こん");
    assert_eq!(safe_slice_up_to(text, 15), "こんにちは");

    // Offsets inside the first character round down to an empty prefix.
    assert_eq!(safe_slice_up_to(text, 1), "");
    assert_eq!(safe_slice_up_to(text, 2), "");

    // Offsets 4 and 5 are inside the second character and round down to "こ".
    assert_eq!(safe_slice_up_to(text, 4), "こ");
    assert_eq!(safe_slice_up_to(text, 5), "こ");

    // Offset 14 is inside the last character, which is therefore dropped.
    assert_eq!(safe_slice_up_to(text, 14), "こんにち");
}
382+
383+
#[test]
fn test_safe_slice_from_multibyte() {
    // "こんにちは" - each character is 3 bytes, 15 bytes in total.
    let text = "こんにちは";

    // Offsets that sit exactly on a character boundary.
    assert_eq!(safe_slice_from(text, 0), "こんにちは");
    assert_eq!(safe_slice_from(text, 3), "んにちは");

    // Offsets 4 and 5 are inside the second character and skip ahead to offset 6.
    assert_eq!(safe_slice_from(text, 4), "にちは");
    assert_eq!(safe_slice_from(text, 5), "にちは");

    // Offsets inside the last character skip to the end of the string.
    assert_eq!(safe_slice_from(text, 13), "");
    assert_eq!(safe_slice_from(text, 14), "");

    // At or beyond the end the suffix is empty rather than a panic.
    assert_eq!(safe_slice_from(text, 15), "");
    assert_eq!(safe_slice_from(text, 100), "");
}
391+
392+
#[test]
fn test_extract_mention_with_multibyte_prefix() {
    let valid = vec!["general"];

    // Multi-byte characters before the mention push its byte offsets past the
    // character count, which is the situation that byte slicing must survive.
    let (mention, remaining) = extract_mention_and_text("日本語 @general search files", &valid)
        .expect("a valid mention after a multi-byte prefix should be found");
    assert_eq!(mention.agent_name, "general");

    // The text on both sides of the mention is kept, and the mention itself is gone.
    assert!(remaining.starts_with("日本語"));
    assert!(remaining.contains("search files"));
    assert!(!remaining.contains("@general"));
}
404+
405+
#[test]
fn test_starts_with_mention_multibyte() {
    let valid = vec!["general"];

    // Leading ASCII whitespace before the mention still counts as "at the start".
    assert!(starts_with_mention(" @general task", &valid));

    // Leading multi-byte whitespace (U+3000 IDEOGRAPHIC SPACE is 3 bytes in UTF-8)
    // must behave the same way and must not cause a panic.
    assert!(starts_with_mention("\u{3000}@general task", &valid));

    // Multi-byte, non-whitespace characters before the mention - should return
    // false, not panic.
    assert!(!starts_with_mention("日本語 @general task", &valid));
}
415+
416+
#[test]
fn test_parse_message_for_agent_multibyte() {
    let agents = ["general"];

    // A CJK prefix places the mention away from the start of the message:
    // parsing must not panic and must not request a task.
    let prefixed = parse_message_for_agent("日本語 @general find files", &agents);
    assert!(!prefixed.should_invoke_task);

    // With the mention first and multi-byte text after it, the task is requested
    // and the multi-byte prompt comes through intact.
    let mention_first = parse_message_for_agent("@general 日本語を検索", &agents);
    assert!(mention_first.should_invoke_task);
    assert_eq!(mention_first.agent, Some(String::from("general")));
    assert_eq!(mention_first.prompt, "日本語を検索");
}
431+
432+
#[test]
fn test_extract_mention_with_emoji() {
    let valid = vec!["general"];

    // Emojis are 4 bytes each, the widest UTF-8 encoding a prefix can contain.
    let (mention, remaining) = extract_mention_and_text("🎉 @general celebrate", &valid)
        .expect("a valid mention after an emoji prefix should be found");
    assert_eq!(mention.agent_name, "general");

    // The emoji must survive intact, the trailing text is kept, and the mention
    // itself is removed.
    assert!(remaining.starts_with("🎉"));
    assert!(remaining.contains("celebrate"));
    assert!(!remaining.contains("@general"));
}
443+
444+
#[test]
fn test_mixed_multibyte_and_ascii() {
    let valid = vec!["general"];

    // Mix of ASCII (1 byte), CJK (3 bytes), and emoji (4 bytes) around the mention.
    let text = "Hello 世界 🌍 @general search for 日本語";
    let (mention, remaining) = extract_mention_and_text(text, &valid)
        .expect("a valid mention inside mixed-width text should be found");
    assert_eq!(mention.agent_name, "general");

    // Should not panic and should keep the text on both sides of the mention
    // while dropping the mention itself.
    assert!(!remaining.is_empty());
    assert!(remaining.starts_with("Hello 世界 🌍"));
    assert!(remaining.contains("search for 日本語"));
    assert!(!remaining.contains("@general"));
}
321457
}

0 commit comments

Comments
 (0)