diff --git a/datafusion/functions/benches/split_part.rs b/datafusion/functions/benches/split_part.rs index 72ca6f66a00d4..0f4998effc2ac 100644 --- a/datafusion/functions/benches/split_part.rs +++ b/datafusion/functions/benches/split_part.rs @@ -18,6 +18,7 @@ use arrow::array::{ArrayRef, Int64Array, StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDF}; use datafusion_functions::string::split_part; @@ -29,15 +30,15 @@ use std::sync::Arc; const N_ROWS: usize = 8192; -/// Creates strings with `num_parts` random alphanumeric segments of `part_len` -/// bytes each, joined by `delimiter`. -fn gen_split_part_data( +/// Creates an array of strings with `num_parts` random alphanumeric segments +/// of `part_len` bytes each, joined by `delimiter`. +fn gen_string_array( n_rows: usize, num_parts: usize, part_len: usize, delimiter: &str, use_string_view: bool, -) -> (ColumnarValue, ColumnarValue) { +) -> ColumnarValue { let mut rng = StdRng::seed_from_u64(42); let mut strings: Vec<String> = Vec::with_capacity(n_rows); @@ -54,22 +55,12 @@ fn gen_split_part_data( strings.push(parts.join(delimiter)); } - let delimiters: Vec<String> = vec![delimiter.to_string(); n_rows]; - if use_string_view { let string_array: StringViewArray = strings.into_iter().map(Some).collect(); - let delimiter_array: StringViewArray = delimiters.into_iter().map(Some).collect(); - ( - ColumnarValue::Array(Arc::new(string_array) as ArrayRef), - ColumnarValue::Array(Arc::new(delimiter_array) as ArrayRef), - ) + ColumnarValue::Array(Arc::new(string_array) as ArrayRef) } else { let string_array: StringArray = strings.into_iter().map(Some).collect(); - let delimiter_array: StringArray = delimiters.into_iter().map(Some).collect(); - ( - ColumnarValue::Array(Arc::new(string_array) as 
ArrayRef), - ColumnarValue::Array(Arc::new(delimiter_array) as ArrayRef), - ) + ColumnarValue::Array(Arc::new(string_array) as ArrayRef) } } @@ -81,12 +72,10 @@ fn bench_split_part( name: &str, tag: &str, strings: ColumnarValue, - delimiters: ColumnarValue, - position: i64, + delimiter: ColumnarValue, + position: ColumnarValue, ) { - let positions: ColumnarValue = - ColumnarValue::Array(Arc::new(Int64Array::from(vec![position; N_ROWS]))); - let args = vec![strings, delimiters, positions]; + let args = vec![strings, delimiter, position]; let arg_fields: Vec<_> = args .iter() .enumerate() @@ -119,108 +108,143 @@ fn criterion_benchmark(c: &mut Criterion) { let config_options = Arc::new(ConfigOptions::default()); let mut group = c.benchmark_group("split_part"); - // Utf8, single-char delimiter, first position + // ── Scalar delimiter and position ──────────────── + + // Utf8, single-char delimiter, scalar args { - let (strings, delimiters) = gen_split_part_data(N_ROWS, 10, 8, ".", false); + let strings = gen_string_array(N_ROWS, 10, 8, ".", false); + let delimiter = ColumnarValue::Scalar(ScalarValue::Utf8(Some(".".into()))); + let position = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); bench_split_part( &mut group, &split_part_func, &config_options, - "utf8_single_char", + "scalar_utf8_single_char", "pos_first", strings, - delimiters, - 1, + delimiter, + position, ); } - // Utf8, single-char delimiter, middle position { - let (strings, delimiters) = gen_split_part_data(N_ROWS, 10, 8, ".", false); + let strings = gen_string_array(N_ROWS, 10, 8, ".", false); + let delimiter = ColumnarValue::Scalar(ScalarValue::Utf8(Some(".".into()))); + let position = ColumnarValue::Scalar(ScalarValue::Int64(Some(5))); bench_split_part( &mut group, &split_part_func, &config_options, - "utf8_single_char", + "scalar_utf8_single_char", "pos_middle", strings, - delimiters, - 5, + delimiter, + position, ); } - // Utf8, single-char delimiter, negative position { - let (strings, 
delimiters) = gen_split_part_data(N_ROWS, 10, 8, ".", false); + let strings = gen_string_array(N_ROWS, 10, 8, ".", false); + let delimiter = ColumnarValue::Scalar(ScalarValue::Utf8(Some(".".into()))); + let position = ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))); bench_split_part( &mut group, &split_part_func, &config_options, - "utf8_single_char", + "scalar_utf8_single_char", "pos_negative", strings, - delimiters, - -1, + delimiter, + position, ); } - // Utf8, multi-char delimiter, middle position + // Utf8, multi-char delimiter, scalar args { - let (strings, delimiters) = gen_split_part_data(N_ROWS, 10, 8, "~@~", false); + let strings = gen_string_array(N_ROWS, 10, 8, "~@~", false); + let delimiter = ColumnarValue::Scalar(ScalarValue::Utf8(Some("~@~".into()))); + let position = ColumnarValue::Scalar(ScalarValue::Int64(Some(5))); bench_split_part( &mut group, &split_part_func, &config_options, - "utf8_multi_char", + "scalar_utf8_multi_char", "pos_middle", strings, - delimiters, - 5, + delimiter, + position, ); } - // Utf8View, single-char delimiter, first position + // Utf8, long strings, scalar args { - let (strings, delimiters) = gen_split_part_data(N_ROWS, 10, 8, ".", true); + let strings = gen_string_array(N_ROWS, 50, 16, ".", false); + let delimiter = ColumnarValue::Scalar(ScalarValue::Utf8(Some(".".into()))); + let position = ColumnarValue::Scalar(ScalarValue::Int64(Some(25))); bench_split_part( &mut group, &split_part_func, &config_options, - "utf8view_single_char", - "pos_first", + "scalar_utf8_long_strings", + "pos_middle", strings, - delimiters, - 1, + delimiter, + position, ); } - // Utf8, single-char delimiter, many long parts + // Utf8View, long parts, scalar args + { + let strings = gen_string_array(N_ROWS, 10, 32, ".", true); + let delimiter = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(".".into()))); + let position = ColumnarValue::Scalar(ScalarValue::Int64(Some(5))); + bench_split_part( + &mut group, + &split_part_func, + 
&config_options, + "scalar_utf8view_long_parts", + "pos_middle", + strings, + delimiter, + position, + ); + } + + // ── Array delimiter and position ───────────────── + + // Utf8, single-char delimiter, array args { - let (strings, delimiters) = gen_split_part_data(N_ROWS, 50, 16, ".", false); + let strings = gen_string_array(N_ROWS, 10, 8, ".", false); + let delimiters: StringArray = vec![Some("."); N_ROWS].into_iter().collect(); + let delimiter = ColumnarValue::Array(Arc::new(delimiters) as ArrayRef); + let positions = ColumnarValue::Array(Arc::new(Int64Array::from(vec![5; N_ROWS]))); bench_split_part( &mut group, &split_part_func, &config_options, - "utf8_long_strings", + "array_utf8_single_char", "pos_middle", strings, - delimiters, - 25, + delimiter, + positions, ); } - // Utf8View, single-char delimiter, middle position, long parts + // Utf8, multi-char delimiter, array args { - let (strings, delimiters) = gen_split_part_data(N_ROWS, 10, 32, ".", true); + let strings = gen_string_array(N_ROWS, 10, 8, "~@~", false); + let delimiters: StringArray = vec![Some("~@~"); N_ROWS].into_iter().collect(); + let delimiter = ColumnarValue::Array(Arc::new(delimiters) as ArrayRef); + let positions = ColumnarValue::Array(Arc::new(Int64Array::from(vec![5; N_ROWS]))); bench_split_part( &mut group, &split_part_func, &config_options, - "utf8view_long_parts", + "array_utf8_multi_char", "pos_middle", strings, - delimiters, - 5, + delimiter, + positions, ); } diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index 87beacabe8491..d29eb7d8c7483 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -17,8 +17,8 @@ use crate::utils::utf8_to_str_type; use arrow::array::{ - ArrayRef, AsArray, GenericStringBuilder, Int64Array, StringArrayType, - StringLikeArrayBuilder, StringViewBuilder, + Array, ArrayRef, AsArray, GenericStringBuilder, Int64Array, StringArrayType, + 
StringLikeArrayBuilder, StringViewBuilder, new_null_array, }; use arrow::datatypes::DataType; use datafusion_common::ScalarValue; @@ -30,6 +30,7 @@ use datafusion_expr::{ }; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; +use memchr::memmem; use std::sync::Arc; #[user_doc( @@ -101,6 +102,16 @@ impl ScalarUDFImpl for SplitPartFunc { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { let ScalarFunctionArgs { args, .. } = args; + // Fast path: array string, scalar delimiter and position. + if let ( + ColumnarValue::Array(string_array), + ColumnarValue::Scalar(delim_scalar), + ColumnarValue::Scalar(pos_scalar), + ) = (&args[0], &args[1], &args[2]) + { + return split_part_scalar(string_array, delim_scalar, pos_scalar); + } + // First, determine if any of the arguments is an Array let len = args.iter().find_map(|arg| match arg { ColumnarValue::Array(a) => Some(a.len()), @@ -192,7 +203,7 @@ impl ScalarUDFImpl for SplitPartFunc { } } -/// Finds the nth split part of `string` by `delimiter`. +/// Finds the `n`th (0-based) split part of `string` by `delimiter`. #[inline] fn split_nth<'a>(string: &'a str, delimiter: &str, n: usize) -> Option<&'a str> { if delimiter.len() == 1 { @@ -206,7 +217,7 @@ fn split_nth<'a>(string: &'a str, delimiter: &str, n: usize) -> Option<&'a str> } } -/// Like `split_nth` but splits from the right. +/// Like `split_nth` but splits from the right (`n` is 0-based from the end). #[inline] fn rsplit_nth<'a>(string: &'a str, delimiter: &str, n: usize) -> Option<&'a str> { if delimiter.len() == 1 { @@ -220,6 +231,190 @@ fn rsplit_nth<'a>(string: &'a str, delimiter: &str, n: usize) -> Option<&'a str> } } +/// Fast path for `split_part(array, scalar_delimiter, scalar_position)`. 
+fn split_part_scalar( + string_array: &ArrayRef, + delim_scalar: &ScalarValue, + pos_scalar: &ScalarValue, +) -> Result<ColumnarValue> { + let delimiter = delim_scalar.try_as_str().ok_or_else(|| { + exec_datafusion_err!( + "Unsupported delimiter type {:?} for split_part", + delim_scalar.data_type() + ) + })?; + + let position = match pos_scalar { + ScalarValue::Int64(v) => *v, + other => { + return exec_err!( + "Unsupported position type {:?} for split_part", + other.data_type() + ); + } + }; + + if position == Some(0) { + return exec_err!("field position must not be zero"); + } + + // Null delimiter or position → every row is null. + let (Some(delimiter), Some(position)) = (delimiter, position) else { + return Ok(ColumnarValue::Array(new_null_array( + string_array.data_type(), + string_array.len(), + ))); + }; + + let result = match string_array.data_type() { + DataType::Utf8View => split_part_scalar_impl( + string_array.as_string_view(), + delimiter, + position, + StringViewBuilder::with_capacity(string_array.len()), + ), + DataType::Utf8 => { + let arr = string_array.as_string::<i32>(); + split_part_scalar_impl( + arr, + delimiter, + position, + GenericStringBuilder::<i32>::with_capacity( + arr.len(), + arr.value_data().len(), + ), + ) + } + DataType::LargeUtf8 => { + let arr = string_array.as_string::<i64>(); + split_part_scalar_impl( + arr, + delimiter, + position, + GenericStringBuilder::<i64>::with_capacity( + arr.len(), + arr.value_data().len(), + ), + ) + } + other => exec_err!("Unsupported string type {other:?} for split_part"), + }?; + + Ok(ColumnarValue::Array(result)) +} + +/// Inner implementation for the scalar-delimiter, scalar-position fast path. +/// Constructing a `memmem::Finder` is somewhat expensive but it's a win when +/// done once and amortized over the entire batch. 
+fn split_part_scalar_impl<'a, S, B>( + string_array: S, + delimiter: &str, + position: i64, + builder: B, +) -> Result<ArrayRef> +where + S: StringArrayType<'a> + Copy, + B: StringLikeArrayBuilder, +{ + if delimiter.is_empty() { + // PostgreSQL: empty delimiter treats input as a single field, + // so only position 1 or -1 returns the input string. + return if position == 1 || position == -1 { + map_strings(string_array, builder, Some) + } else { + map_strings(string_array, builder, |_| None) + }; + } + + let delim_bytes = delimiter.as_bytes(); + let delim_len = delimiter.len(); + + if position > 0 { + let idx: usize = (position - 1).try_into().map_err(|_| { + exec_datafusion_err!( + "split_part index {position} exceeds maximum supported value" + ) + })?; + let finder = memmem::Finder::new(delim_bytes); + map_strings(string_array, builder, |s| { + split_nth_finder(s, &finder, delim_len, idx) + }) + } else { + let idx: usize = (position.unsigned_abs() - 1).try_into().map_err(|_| { + exec_datafusion_err!( + "split_part index {position} exceeds minimum supported value" + ) + })?; + let finder_rev = memmem::FinderRev::new(delim_bytes); + map_strings(string_array, builder, |s| { + rsplit_nth_finder(s, &finder_rev, delim_len, idx) + }) + } +} + +/// Applies `f` to each non-null string in `string_array`, appending the +/// result (or `""` when `f` returns `None`) to `builder`. +#[inline] +fn map_strings<'a, S, B, F>(string_array: S, mut builder: B, f: F) -> Result<ArrayRef> +where + S: StringArrayType<'a> + Copy, + B: StringLikeArrayBuilder, + F: Fn(&'a str) -> Option<&'a str>, +{ + for string in string_array.iter() { + match string { + Some(s) => builder.append_value(f(s).unwrap_or("")), + None => builder.append_null(), + } + } + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +/// Finds the `n`th (0-based) split part using a pre-built `memmem::Finder`. 
+#[inline] +fn split_nth_finder<'a>( + string: &'a str, + finder: &memmem::Finder, + delim_len: usize, + n: usize, +) -> Option<&'a str> { + let bytes = string.as_bytes(); + let mut start = 0; + for _ in 0..n { + match finder.find(&bytes[start..]) { + Some(pos) => start += pos + delim_len, + None => return None, + } + } + match finder.find(&bytes[start..]) { + Some(pos) => Some(&string[start..start + pos]), + None => Some(&string[start..]), + } +} + +/// Like `split_nth_finder` but splits from the right (`n` is 0-based from +/// the end). +#[inline] +fn rsplit_nth_finder<'a>( + string: &'a str, + finder: &memmem::FinderRev, + delim_len: usize, + n: usize, +) -> Option<&'a str> { + let bytes = string.as_bytes(); + let mut end = bytes.len(); + for _ in 0..n { + match finder.rfind(&bytes[..end]) { + Some(pos) => end = pos, + None => return None, + } + } + match finder.rfind(&bytes[..end]) { + Some(pos) => Some(&string[pos + delim_len..end]), + None => Some(&string[..end]), + } +} + fn split_part_impl<'a, StringArrType, DelimiterArrType, B>( string_array: &StringArrType, delimiter_array: &DelimiterArrType, diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index a6341bc686f74..5ba0da13ff8fd 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -724,6 +724,72 @@ SELECT split_part('a,b', '', -2) statement error DataFusion error: Execution error: field position must not be zero SELECT split_part('abc~@~def~@~ghi', '~@~', 0) +# Position 0 with column input errors even for empty/null inputs +statement error DataFusion error: Execution error: field position must not be zero +SELECT split_part(column1, '.', 0) FROM (VALUES (NULL::text)) AS t(column1) + +# split_part with column input (exercises the scalar-delimiter fast path) +query TTT +SELECT + split_part(column1, '.', 1), + split_part(column1, '.', 2), + split_part(column1, '.', 3) +FROM (VALUES ('a.b.c'), ('d.e.f'), 
('x.y')) AS t(column1) +---- +a b c +d e f +x y (empty) + +# Multi-char delimiter with column input +query TT +SELECT + split_part(column1, '~@~', 2), + split_part(column1, '~@~', 3) +FROM (VALUES ('abc~@~def~@~ghi'), ('one~@~two')) AS t(column1) +---- +def ghi +two (empty) + +# Negative position with column input +query TT +SELECT + split_part(column1, '.', -1), + split_part(column1, '.', -2) +FROM (VALUES ('a.b.c'), ('x.y')) AS t(column1) +---- +c b +y x + +# Empty delimiter with column input +query TT +SELECT + split_part(column1, '', 1), + split_part(column1, '', 2) +FROM (VALUES ('abc'), ('xyz')) AS t(column1) +---- +abc (empty) +xyz (empty) + +# NULL column values with scalar delimiter +query T +SELECT split_part(column1, '.', 2) +FROM (VALUES ('a.b'), (NULL), ('c.d')) AS t(column1) +---- +b +NULL +d + +# Utf8View column with scalar delimiter +query TT +SELECT + split_part(column1, '.', 1), + split_part(column1, '.', 2) +FROM (SELECT arrow_cast(column1, 'Utf8View') AS column1 + FROM (VALUES ('a.b.c'), ('x.y.z')) AS t(column1)) +---- +a b +x y + query B SELECT starts_with('alphabet', 'alph') ----