-
Notifications
You must be signed in to change notification settings - Fork 0
20305: perf: Optimize translate() UDF for scalar inputs #245
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7a2b9f2
d693af3
60d896d
d5c184e
a721f43
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -35,8 +35,8 @@ use datafusion_macros::user_doc; | |||||||||||||||||
|
|
||||||||||||||||||
| #[user_doc( | ||||||||||||||||||
| doc_section(label = "String Functions"), | ||||||||||||||||||
| description = "Translates characters in a string to specified translation characters.", | ||||||||||||||||||
| syntax_example = "translate(str, chars, translation)", | ||||||||||||||||||
| description = "Performs character-wise substitution based on a mapping.", | ||||||||||||||||||
| syntax_example = "translate(str, from, to)", | ||||||||||||||||||
| sql_example = r#"```sql | ||||||||||||||||||
| > select translate('twice', 'wic', 'her'); | ||||||||||||||||||
| +--------------------------------------------------+ | ||||||||||||||||||
|
|
@@ -46,10 +46,10 @@ use datafusion_macros::user_doc; | |||||||||||||||||
| +--------------------------------------------------+ | ||||||||||||||||||
| ```"#, | ||||||||||||||||||
| standard_argument(name = "str", prefix = "String"), | ||||||||||||||||||
| argument(name = "chars", description = "Characters to translate."), | ||||||||||||||||||
| argument(name = "from", description = "The characters to be replaced."), | ||||||||||||||||||
| argument( | ||||||||||||||||||
| name = "translation", | ||||||||||||||||||
| description = "Translation characters. Translation characters replace only characters at the same position in the **chars** string." | ||||||||||||||||||
| name = "to", | ||||||||||||||||||
| description = "The characters to replace them with. Each character in **from** that is found in **str** is replaced by the character at the same index in **to**. Any characters in **from** that don't have a corresponding character in **to** are removed." | ||||||||||||||||||
| ) | ||||||||||||||||||
| )] | ||||||||||||||||||
| #[derive(Debug, PartialEq, Eq, Hash)] | ||||||||||||||||||
|
|
@@ -71,6 +71,7 @@ impl TranslateFunc { | |||||||||||||||||
| vec![ | ||||||||||||||||||
| Exact(vec![Utf8View, Utf8, Utf8]), | ||||||||||||||||||
| Exact(vec![Utf8, Utf8, Utf8]), | ||||||||||||||||||
| Exact(vec![LargeUtf8, Utf8, Utf8]), | ||||||||||||||||||
| ], | ||||||||||||||||||
| Volatility::Immutable, | ||||||||||||||||||
| ), | ||||||||||||||||||
|
|
@@ -99,6 +100,65 @@ impl ScalarUDFImpl for TranslateFunc { | |||||||||||||||||
| &self, | ||||||||||||||||||
| args: datafusion_expr::ScalarFunctionArgs, | ||||||||||||||||||
| ) -> Result<ColumnarValue> { | ||||||||||||||||||
| // When from and to are scalars, pre-build the translation map once | ||||||||||||||||||
| if let (Some(from_str), Some(to_str)) = ( | ||||||||||||||||||
| try_as_scalar_str(&args.args[1]), | ||||||||||||||||||
| try_as_scalar_str(&args.args[2]), | ||||||||||||||||||
| ) { | ||||||||||||||||||
| let from_graphemes: Vec<&str> = from_str.graphemes(true).collect(); | ||||||||||||||||||
| let to_graphemes: Vec<&str> = to_str.graphemes(true).collect(); | ||||||||||||||||||
|
|
||||||||||||||||||
| let mut from_map: HashMap<&str, usize> = HashMap::new(); | ||||||||||||||||||
| for (index, c) in from_graphemes.iter().enumerate() { | ||||||||||||||||||
| // Ignore characters that already exist in from_map | ||||||||||||||||||
| from_map.entry(*c).or_insert(index); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| let ascii_table = build_ascii_translate_table(from_str, to_str); | ||||||||||||||||||
|
|
||||||||||||||||||
| let string_array = match &args.args[0] { | ||||||||||||||||||
| ColumnarValue::Array(arr) => Arc::clone(arr), | ||||||||||||||||||
| ColumnarValue::Scalar(s) => s.to_array_of_size(args.number_rows)?, | ||||||||||||||||||
| }; | ||||||||||||||||||
|
|
||||||||||||||||||
| let result = match string_array.data_type() { | ||||||||||||||||||
| DataType::Utf8View => { | ||||||||||||||||||
| let arr = string_array.as_string_view(); | ||||||||||||||||||
| translate_with_map::<i32, _>( | ||||||||||||||||||
| arr, | ||||||||||||||||||
| &from_map, | ||||||||||||||||||
| &to_graphemes, | ||||||||||||||||||
| ascii_table.as_ref(), | ||||||||||||||||||
| ) | ||||||||||||||||||
| } | ||||||||||||||||||
| DataType::Utf8 => { | ||||||||||||||||||
| let arr = string_array.as_string::<i32>(); | ||||||||||||||||||
| translate_with_map::<i32, _>( | ||||||||||||||||||
| arr, | ||||||||||||||||||
| &from_map, | ||||||||||||||||||
| &to_graphemes, | ||||||||||||||||||
| ascii_table.as_ref(), | ||||||||||||||||||
| ) | ||||||||||||||||||
| } | ||||||||||||||||||
| DataType::LargeUtf8 => { | ||||||||||||||||||
| let arr = string_array.as_string::<i64>(); | ||||||||||||||||||
| translate_with_map::<i64, _>( | ||||||||||||||||||
| arr, | ||||||||||||||||||
| &from_map, | ||||||||||||||||||
| &to_graphemes, | ||||||||||||||||||
| ascii_table.as_ref(), | ||||||||||||||||||
| ) | ||||||||||||||||||
| } | ||||||||||||||||||
| other => { | ||||||||||||||||||
| return exec_err!( | ||||||||||||||||||
| "Unsupported data type {other:?} for function translate" | ||||||||||||||||||
| ); | ||||||||||||||||||
| } | ||||||||||||||||||
| }?; | ||||||||||||||||||
|
|
||||||||||||||||||
| return Ok(ColumnarValue::Array(result)); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| make_scalar_function(invoke_translate, vec![])(&args.args) | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -107,6 +167,14 @@ impl ScalarUDFImpl for TranslateFunc { | |||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| /// If `cv` is a non-null scalar string, return its value. | ||||||||||||||||||
| fn try_as_scalar_str(cv: &ColumnarValue) -> Option<&str> { | ||||||||||||||||||
| match cv { | ||||||||||||||||||
| ColumnarValue::Scalar(s) => s.try_as_str().flatten(), | ||||||||||||||||||
| _ => None, | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| fn invoke_translate(args: &[ArrayRef]) -> Result<ArrayRef> { | ||||||||||||||||||
| match args[0].data_type() { | ||||||||||||||||||
| DataType::Utf8View => { | ||||||||||||||||||
|
|
@@ -170,7 +238,7 @@ where | |||||||||||||||||
| // Build from_map using reusable buffer | ||||||||||||||||||
| from_graphemes.extend(from.graphemes(true)); | ||||||||||||||||||
| for (index, c) in from_graphemes.iter().enumerate() { | ||||||||||||||||||
| // Ignore characters that already exist in from_map, else insert | ||||||||||||||||||
| // Ignore characters that already exist in from_map | ||||||||||||||||||
| from_map.entry(*c).or_insert(index); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -199,6 +267,99 @@ where | |||||||||||||||||
| Ok(Arc::new(result) as ArrayRef) | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| /// Sentinel value in the ASCII translate table indicating the character should | ||||||||||||||||||
| /// be deleted (the `from` character has no corresponding `to` character). Any | ||||||||||||||||||
| /// value > 127 works since valid ASCII is 0–127. | ||||||||||||||||||
| const ASCII_DELETE: u8 = 0xFF; | ||||||||||||||||||
|
|
||||||||||||||||||
| /// If `from` and `to` are both ASCII, build a fixed-size lookup table for | ||||||||||||||||||
| /// translation. Each entry maps an input byte to its replacement byte, or to | ||||||||||||||||||
| /// [`ASCII_DELETE`] if the character should be removed. Returns `None` if | ||||||||||||||||||
| /// either string contains non-ASCII characters. | ||||||||||||||||||
| fn build_ascii_translate_table(from: &str, to: &str) -> Option<[u8; 128]> { | ||||||||||||||||||
| if !from.is_ascii() || !to.is_ascii() { | ||||||||||||||||||
| return None; | ||||||||||||||||||
| } | ||||||||||||||||||
| let mut table = [0u8; 128]; | ||||||||||||||||||
| for i in 0..128u8 { | ||||||||||||||||||
| table[i as usize] = i; | ||||||||||||||||||
| } | ||||||||||||||||||
| let to_bytes = to.as_bytes(); | ||||||||||||||||||
| let mut seen = [false; 128]; | ||||||||||||||||||
| for (i, from_byte) in from.bytes().enumerate() { | ||||||||||||||||||
| let idx = from_byte as usize; | ||||||||||||||||||
| if !seen[idx] { | ||||||||||||||||||
| seen[idx] = true; | ||||||||||||||||||
| if i < to_bytes.len() { | ||||||||||||||||||
| table[idx] = to_bytes[i]; | ||||||||||||||||||
| } else { | ||||||||||||||||||
| table[idx] = ASCII_DELETE; | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
| Some(table) | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| /// Optimized translate for constant `from` and `to` arguments: uses a pre-built | ||||||||||||||||||
| /// translation map instead of rebuilding it for every row. When an ASCII byte | ||||||||||||||||||
| /// lookup table is provided, ASCII input rows the lookup table; non-ASCII | ||||||||||||||||||
| /// inputs fallback to using the map. | ||||||||||||||||||
|
Comment on lines
+303
to
+306
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Typo in doc comment. Line 305: "ASCII input rows the lookup table" → "ASCII input rows use the lookup table". 📝 Fix-/// lookup table is provided, ASCII input rows the lookup table; non-ASCII
+/// lookup table is provided, ASCII input rows use the lookup table; non-ASCII📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. value:good-to-have; category:typo; feedback: The CodeRabbit AI reviewer is correct! The is a typo in the documentation that should be fixed. |
||||||||||||||||||
| fn translate_with_map<'a, T: OffsetSizeTrait, V>( | ||||||||||||||||||
| string_array: V, | ||||||||||||||||||
| from_map: &HashMap<&str, usize>, | ||||||||||||||||||
| to_graphemes: &[&str], | ||||||||||||||||||
| ascii_table: Option<&[u8; 128]>, | ||||||||||||||||||
| ) -> Result<ArrayRef> | ||||||||||||||||||
| where | ||||||||||||||||||
| V: ArrayAccessor<Item = &'a str>, | ||||||||||||||||||
| { | ||||||||||||||||||
| let mut string_graphemes: Vec<&str> = Vec::new(); | ||||||||||||||||||
| let mut result_graphemes: Vec<&str> = Vec::new(); | ||||||||||||||||||
| let mut ascii_buf: Vec<u8> = Vec::new(); | ||||||||||||||||||
|
|
||||||||||||||||||
| let result = ArrayIter::new(string_array) | ||||||||||||||||||
| .map(|string| { | ||||||||||||||||||
| string.map(|s| { | ||||||||||||||||||
| // Fast path: byte-level table lookup for ASCII strings | ||||||||||||||||||
| if let Some(table) = ascii_table | ||||||||||||||||||
| && s.is_ascii() | ||||||||||||||||||
| { | ||||||||||||||||||
| ascii_buf.clear(); | ||||||||||||||||||
| for &b in s.as_bytes() { | ||||||||||||||||||
| let mapped = table[b as usize]; | ||||||||||||||||||
| if mapped != ASCII_DELETE { | ||||||||||||||||||
| ascii_buf.push(mapped); | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
| // ascii_buf contains only ASCII bytes, so it is valid | ||||||||||||||||||
| // UTF-8. | ||||||||||||||||||
| return String::from_utf8(ascii_buf.clone()).unwrap(); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| // Slow path: grapheme-based translation | ||||||||||||||||||
| string_graphemes.clear(); | ||||||||||||||||||
| result_graphemes.clear(); | ||||||||||||||||||
|
|
||||||||||||||||||
| string_graphemes.extend(s.graphemes(true)); | ||||||||||||||||||
| for c in &string_graphemes { | ||||||||||||||||||
| match from_map.get(*c) { | ||||||||||||||||||
| Some(n) => { | ||||||||||||||||||
| if let Some(replacement) = to_graphemes.get(*n) { | ||||||||||||||||||
| result_graphemes.push(*replacement); | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
| None => result_graphemes.push(*c), | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
Comment on lines
+340
to
+353
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The You would also need to remove the declaration of result_graphemes.clear();
for c in s.graphemes(true) {
match from_map.get(c) {
Some(n) => {
if let Some(replacement) = to_graphemes.get(*n) {
result_graphemes.push(*replacement);
}
}
None => result_graphemes.push(c),
}
}
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. value:useful; category:bug; feedback: The Gemini AI reviewer is correct! There is no need to store the graphemes in a local variable that makes an unnecessary allocation. Just iterating over the graphemes is enough for the logic needs. Prevents a heap allocation that is not really necessary. |
||||||||||||||||||
|
|
||||||||||||||||||
| result_graphemes.concat() | ||||||||||||||||||
| }) | ||||||||||||||||||
| }) | ||||||||||||||||||
| .collect::<GenericStringArray<T>>(); | ||||||||||||||||||
|
|
||||||||||||||||||
| Ok(Arc::new(result) as ArrayRef) | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| #[cfg(test)] | ||||||||||||||||||
| mod tests { | ||||||||||||||||||
| use arrow::array::{Array, StringArray}; | ||||||||||||||||||
|
|
@@ -284,6 +445,21 @@ mod tests { | |||||||||||||||||
| Utf8, | ||||||||||||||||||
| StringArray | ||||||||||||||||||
| ); | ||||||||||||||||||
| // Non-ASCII input with ASCII scalar from/to: exercises the | ||||||||||||||||||
| // grapheme fallback within translate_with_map. | ||||||||||||||||||
| test_function!( | ||||||||||||||||||
| TranslateFunc::new(), | ||||||||||||||||||
| vec![ | ||||||||||||||||||
| ColumnarValue::Scalar(ScalarValue::from("café")), | ||||||||||||||||||
| ColumnarValue::Scalar(ScalarValue::from("ae")), | ||||||||||||||||||
| ColumnarValue::Scalar(ScalarValue::from("AE")) | ||||||||||||||||||
| ], | ||||||||||||||||||
| Ok(Some("cAfé")), | ||||||||||||||||||
| &str, | ||||||||||||||||||
| Utf8, | ||||||||||||||||||
| StringArray | ||||||||||||||||||
| ); | ||||||||||||||||||
|
|
||||||||||||||||||
| #[cfg(not(feature = "unicode_expressions"))] | ||||||||||||||||||
| test_function!( | ||||||||||||||||||
| TranslateFunc::new(), | ||||||||||||||||||
|
|
||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new signature arm
Exact(vec![LargeUtf8, Utf8, Utf8])looks inconsistent withinvoke_translate’sDataType::LargeUtf8branch, which downcastsfrom/toviaas_string::<i64>(). Iffrom/toare arrays (as allowed by the signature), this will likely panic due to offset-size mismatch (Utf8/i32 vs LargeUtf8/i64).Severity: high
Other Locations
datafusion/functions/src/unicode/translate.rs:194datafusion/functions/src/unicode/translate.rs:195🤖 Was this useful? React with 👍 or 👎, or 🚀 if it prevented an incident/outage.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
value:useful; category:bug; feedback: The Augment AI reviewer is correct! There is no need to use LargeUtf8 for the from and to arguments, since there is no alphabet that long (u64::MAX). u32::MAX is more than enough for any alphabet. There is a bug in the as_string::() calls though! They should use as_string::() since this is Utf8 type.