From 6cf94c73d613e93e9a6e2b9efa65de6cca568f78 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 9 Apr 2026 09:22:42 -0400 Subject: [PATCH 1/4] Add more regexp_replace test coverage (#21485) ## Which issue does this PR close? - related to https://github.com/apache/datafusion/pull/21379 ## Rationale for this change While reviewing https://github.com/apache/datafusion/pull/21379 I noticed there was minimal Utf8View coverage of the related code. ## What changes are included in this PR? Update the regexp_replace tests to cover utf8, largeutf8, utf8view and dictionary ## Are these changes tested? Yes only tests I verified these tests also pass when run on - https://github.com/apache/datafusion/pull/21379 ## Are there any user-facing changes? No --- .../test_files/regexp/regexp_replace.slt | 145 +++++++++++++----- 1 file changed, 108 insertions(+), 37 deletions(-) diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_replace.slt b/datafusion/sqllogictest/test_files/regexp/regexp_replace.slt index 99e0b10430186..e27ff1e9c1a00 100644 --- a/datafusion/sqllogictest/test_files/regexp/regexp_replace.slt +++ b/datafusion/sqllogictest/test_files/regexp/regexp_replace.slt @@ -128,43 +128,6 @@ from (values ('a'), ('b')) as tbl(col); NULL NULL NULL NULL NULL NULL -# Extract domain from URL using anchored pattern with trailing .* -# This tests that the full URL suffix is replaced, not just the matched prefix -query T -SELECT regexp_replace(url, '^https?://(?:www\.)?([^/]+)/.*$', '\1') FROM (VALUES - ('https://www.example.com/path/to/page?q=1'), - ('http://test.org/foo/bar'), - ('https://example.com/'), - ('not-a-url') -) AS t(url); ----- -example.com -test.org -example.com -not-a-url - -# More than one capture group should disable the short-regex fast path. -# This still uses replacement \1, but captures_len() will be > 2, so the -# implementation must fall back to the normal regexp_replace path. 
-query T -SELECT regexp_replace(url, '^https?://((www\.)?([^/]+))/.*$', '\1') FROM (VALUES - ('https://www.example.com/path/to/page?q=1'), - ('http://test.org/foo/bar'), - ('not-a-url') -) AS t(url); ----- -www.example.com -test.org -not-a-url - -# If the overall pattern matches but capture group 1 does not participate, -# regexp_replace(..., '\1') should substitute the empty string, not keep -# the original input. -query B -SELECT regexp_replace('bzzz', '^(a)?b.*$', '\1') = ''; ----- -true - # Stripping trailing .*$ must not change match semantics for inputs with # newlines when the original pattern does not use the 's' flag. query B @@ -183,3 +146,111 @@ SELECT regexp_replace( ) = concat('x', chr(10), 'rest'); ---- true + + +# Fixture for testing optimizations in regexp_replace +statement ok +CREATE TABLE regexp_replace_optimized_cases ( + value string, + regexp string, + replacement string, + expected string +); + +# Extract domain from URL using anchored pattern with trailing .* +# This tests that the full URL suffix is replaced, not just the matched prefix. +statement ok +INSERT INTO regexp_replace_optimized_cases VALUES + ('https://www.example.com/path/to/page?q=1', '^https?://(?:www\.)?([^/]+)/.*$', '\1', 'example.com'), + ('http://test.org/foo/bar', '^https?://(?:www\.)?([^/]+)/.*$', '\1', 'test.org'), + ('https://example.com/', '^https?://(?:www\.)?([^/]+)/.*$', '\1', 'example.com'), + ('not-a-url', '^https?://(?:www\.)?([^/]+)/.*$', '\1', 'not-a-url'); + +# More than one capture group should disable the short-regex fast path. +# This still uses replacement \1, but captures_len() will be > 2, so the +# implementation must fall back to the normal regexp_replace path. 
+statement ok +INSERT INTO regexp_replace_optimized_cases VALUES + ('https://www.example.com/path/to/page?q=1', '^https?://((www\.)?([^/]+))/.*$', '\1', 'www.example.com'), + ('http://test.org/foo/bar', '^https?://((www\.)?([^/]+))/.*$', '\1', 'test.org'), + ('not-a-url', '^https?://((www\.)?([^/]+))/.*$', '\1', 'not-a-url'); + +# If the overall pattern matches but capture group 1 does not participate, +# regexp_replace(..., '\1') should substitute the empty string, not keep +# the original input. +statement ok +INSERT INTO regexp_replace_optimized_cases VALUES + ('bzzz', '^(a)?b.*$', '\1', ''); + + +query TB +SELECT value, regexp_replace(value, regexp, replacement) = expected +FROM regexp_replace_optimized_cases +ORDER BY regexp, value, replacement, expected; +---- +bzzz true +http://test.org/foo/bar true +https://www.example.com/path/to/page?q=1 true +not-a-url true +http://test.org/foo/bar true +https://example.com/ true +https://www.example.com/path/to/page?q=1 true +not-a-url true + +query TB +SELECT value, regexp_replace( + arrow_cast(value, 'LargeUtf8'), + arrow_cast(regexp, 'LargeUtf8'), + arrow_cast(replacement, 'LargeUtf8') + ) = arrow_cast(expected, 'LargeUtf8') +FROM regexp_replace_optimized_cases +ORDER BY regexp, value, replacement, expected; +---- +bzzz true +http://test.org/foo/bar true +https://www.example.com/path/to/page?q=1 true +not-a-url true +http://test.org/foo/bar true +https://example.com/ true +https://www.example.com/path/to/page?q=1 true +not-a-url true + +query TB +SELECT value, regexp_replace( + arrow_cast(value, 'Utf8View'), + arrow_cast(regexp, 'Utf8View'), + arrow_cast(replacement, 'Utf8View') + ) = arrow_cast(expected, 'Utf8View') +FROM regexp_replace_optimized_cases +ORDER BY regexp, value, replacement, expected; +---- +bzzz true +http://test.org/foo/bar true +https://www.example.com/path/to/page?q=1 true +not-a-url true +http://test.org/foo/bar true +https://example.com/ true +https://www.example.com/path/to/page?q=1 true 
+not-a-url true + +query TB +SELECT value, regexp_replace( + arrow_cast(value, 'Dictionary(Int32, Utf8)'), + arrow_cast(regexp, 'Dictionary(Int32, Utf8)'), + arrow_cast(replacement, 'Dictionary(Int32, Utf8)') + ) = arrow_cast(expected, 'Dictionary(Int32, Utf8)') +FROM regexp_replace_optimized_cases +ORDER BY regexp, value, replacement, expected; +---- +bzzz true +http://test.org/foo/bar true +https://www.example.com/path/to/page?q=1 true +not-a-url true +http://test.org/foo/bar true +https://example.com/ true +https://www.example.com/path/to/page?q=1 true +not-a-url true + +# cleanup +statement ok +DROP TABLE regexp_replace_optimized_cases; From 6c106ba4e3f6ce797f983a70224c5aefd2a48831 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Thu, 9 Apr 2026 09:24:49 -0400 Subject: [PATCH 2/4] fix: Use codepoints in `lpad`, `rpad`, `translate` (#21405) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Which issue does this PR close? - Closes #21060. ## Rationale for this change `lpad`, `rpad`, and `translate` use grapheme segmentation. This is inconsistent with how these functions behave in Postgres and DuckDB, as well as the SQL standard -- segmentation based on Unicode codepoints is used instead. It also happens that grapheme-based segmentation is significantly more expensive than codepoint-based segmentation. In the case of `lpad` and `rpad`, graphemes and codepoints were used inconsistently: the input string was measured in code points but the fill string was measured in graphemes. #3054 switched to using codepoints for most string-related functions in DataFusion but these three functions still need to be changed. 
Benchmarks (M4 Max): lpad size=1024: - lpad utf8 [str_len=5, target=20]: 12.4 µs → 12.8 µs, +3.0% - lpad stringview [str_len=5, target=20]: 11.5 µs → 11.7 µs, +1.4% - lpad utf8 [str_len=20, target=50]: 11.3 µs → 11.3 µs, +0.1% - lpad stringview [str_len=20, target=50]: 11.8 µs → 12.0 µs, +1.6% - lpad utf8 unicode [target=20]: 98.4 µs → 24.4 µs, -75.1% - lpad stringview unicode [target=20]: 99.8 µs → 26.0 µs, -74.0% - lpad utf8 scalar [str_len=5, target=20, fill='x']: 8.7 µs → 8.8 µs, +1.0% - lpad stringview scalar [str_len=5, target=20, fill='x']: 10.2 µs → 10.1 µs, -0.1% - lpad utf8 scalar unicode [str_len=5, target=20, fill='é']: 44.7 µs → 10.9 µs, -75.7% - lpad utf8 scalar truncate [str_len=20, target=5, fill='é']: 152.5 µs → 11.7 µs, -92.3% lpad size=4096: - lpad utf8 [str_len=5, target=20]: 55.9 µs → 55.1 µs, -1.4% - lpad stringview [str_len=5, target=20]: 49.2 µs → 50.1 µs, +1.8% - lpad utf8 [str_len=20, target=50]: 46.6 µs → 46.4 µs, -0.5% - lpad stringview [str_len=20, target=50]: 47.5 µs → 48.5 µs, +2.1% - lpad utf8 unicode [target=20]: 401.3 µs → 100.1 µs, -75.0% - lpad stringview unicode [target=20]: 397.7 µs → 104.9 µs, -73.6% - lpad utf8 scalar [str_len=5, target=20, fill='x']: 34.2 µs → 35.0 µs, +2.4% - lpad stringview scalar [str_len=5, target=20, fill='x']: 40.1 µs → 40.4 µs, +0.6% - lpad utf8 scalar unicode [str_len=5, target=20, fill='é']: 178.3 µs → 42.9 µs, -76.0% - lpad utf8 scalar truncate [str_len=20, target=5, fill='é']: 601.3 µs → 46.2 µs, -92.3% rpad size=1024: - rpad utf8 [str_len=5, target=20]: 15.5 µs → 14.4 µs, -7.1% - rpad stringview [str_len=5, target=20]: 13.8 µs → 14.0 µs, +1.7% - rpad utf8 [str_len=20, target=50]: 12.6 µs → 12.7 µs, +1.3% - rpad stringview [str_len=20, target=50]: 13.0 µs → 13.1 µs, +0.7% - rpad utf8 unicode [target=20]: 103.5 µs → 26.0 µs, -74.8% - rpad stringview unicode [target=20]: 101.2 µs → 27.6 µs, -72.7% - rpad utf8 scalar [str_len=5, target=20, fill='x']: 11.4 µs → 10.9 µs, -3.9% - rpad stringview scalar 
[str_len=5, target=20, fill='x']: 12.2 µs → 12.6 µs, +2.8% - rpad utf8 scalar unicode [str_len=5, target=20, fill='é']: 46.3 µs → 12.4 µs, -73.1% - rpad utf8 scalar truncate [str_len=20, target=5, fill='é']: 155.6 µs → 11.6 µs, -92.4% rpad size=4096: - rpad utf8 [str_len=5, target=20]: 70.1 µs → 61.6 µs, -12.2% - rpad stringview [str_len=5, target=20]: 60.4 µs → 56.8 µs, -6.0% - rpad utf8 [str_len=20, target=50]: 50.6 µs → 51.2 µs, +1.2% - rpad stringview [str_len=20, target=50]: 53.7 µs → 53.3 µs, -0.8% - rpad utf8 unicode [target=20]: 407.1 µs → 104.0 µs, -74.5% - rpad stringview unicode [target=20]: 404.8 µs → 114.5 µs, -71.7% - rpad utf8 scalar [str_len=5, target=20, fill='x']: 47.5 µs → 45.6 µs, -4.0% - rpad stringview scalar [str_len=5, target=20, fill='x']: 56.4 µs → 58.5 µs, +3.6% - rpad utf8 scalar unicode [str_len=5, target=20, fill='é']: 184.1 µs → 48.1 µs, -73.9% - rpad utf8 scalar truncate [str_len=20, target=5, fill='é']: 606.4 µs → 45.6 µs, -92.5% translate size=1024: - array_from_to [str_len=8]: 140.0 µs → 37.6 µs, -73.2% - scalar_from_to [str_len=8]: 9.0 µs → 8.8 µs, -2.7% - array_from_to [str_len=32]: 371.3 µs → 65.6 µs, -82.3% - scalar_from_to [str_len=32]: 19.9 µs → 19.2 µs, -3.6% - array_from_to [str_len=128]: 1249.6 µs → 188.7 µs, -84.9% - scalar_from_to [str_len=128]: 70.2 µs → 64.7 µs, -7.9% - array_from_to [str_len=1024]: 9349.4 µs → 1378.1 µs, -85.3% - scalar_from_to [str_len=1024]: 506.5 µs → 445.8 µs, -12.0% translate size=4096: - array_from_to [str_len=8]: 548.0 µs → 147.1 µs, -73.2% - scalar_from_to [str_len=8]: 33.9 µs → 32.8 µs, -3.1% - array_from_to [str_len=32]: 1457.2 µs → 266.0 µs, -81.7% - scalar_from_to [str_len=32]: 78.0 µs → 75.5 µs, -3.2% - array_from_to [str_len=128]: 4935.0 µs → 791.1 µs, -84.0% - scalar_from_to [str_len=128]: 278.2 µs → 260.7 µs, -6.3% - array_from_to [str_len=1024]: 37496 µs → 5591 µs, -85.1% - scalar_from_to [str_len=1024]: 2058.0 µs → 1770 µs, -14.0% ## What changes are included in this PR? 
* Switch from grapheme segmentation to codepoint segmentation for `lpad`, `rpad`, and `translate` * Add SLT tests * Refactor a few helper functions * Remove dependency on `unicode_segmentation` crate as it is no longer used ## Are these changes tested? Yes. The new SLT tests were also run against DuckDB and Postgres to confirm the behavior is consistent. ## Are there any user-facing changes? Yes. This PR changes the behavior of `lpad`, `rpad`, and `translate`, although the new behavior is more consistent with the SQL standard and with other SQL implementations. --- Cargo.lock | 1 - datafusion/functions/Cargo.toml | 3 +- datafusion/functions/src/unicode/common.rs | 41 ++++++- datafusion/functions/src/unicode/lpad.rs | 88 ++++++-------- datafusion/functions/src/unicode/rpad.rs | 90 +++++++-------- datafusion/functions/src/unicode/translate.rs | 108 ++++++++---------- .../test_files/string/string_literal.slt | 81 +++++++++++++ .../library-user-guide/upgrading/54.0.0.md | 12 ++ 8 files changed, 255 insertions(+), 169 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 983a74ed3cf03..87c18826096c2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2259,7 +2259,6 @@ dependencies = [ "regex", "sha2", "tokio", - "unicode-segmentation", "uuid", ] diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 7503c337517ef..02b8e842280bf 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -59,7 +59,7 @@ regex_expressions = ["regex"] # enable string functions string_expressions = ["uuid"] # enable unicode functions -unicode_expressions = ["unicode-segmentation"] +unicode_expressions = [] [lib] name = "datafusion_functions" @@ -87,7 +87,6 @@ num-traits = { workspace = true } rand = { workspace = true } regex = { workspace = true, optional = true } sha2 = { workspace = true, optional = true } -unicode-segmentation = { version = "^1.13.2", optional = true } uuid = { workspace = true, features = ["v4"], optional = true } 
[dev-dependencies] diff --git a/datafusion/functions/src/unicode/common.rs b/datafusion/functions/src/unicode/common.rs index 002776e6c6538..0158325e98a19 100644 --- a/datafusion/functions/src/unicode/common.rs +++ b/datafusion/functions/src/unicode/common.rs @@ -78,6 +78,39 @@ impl LeftRightSlicer for RightSlicer { } } +/// Returns the byte offset of the `n`th codepoint in `string`, +/// or `string.len()` if the string has fewer than `n` codepoints. +#[inline] +pub(crate) fn byte_offset_of_char(string: &str, n: usize) -> usize { + string + .char_indices() + .nth(n) + .map_or(string.len(), |(i, _)| i) +} + +/// If `string` has more than `n` codepoints, returns the byte offset of +/// the `n`-th codepoint boundary. Otherwise returns the total codepoint count. +#[inline] +pub(crate) fn char_count_or_boundary(string: &str, n: usize) -> StringCharLen { + let mut count = 0; + for (byte_idx, _) in string.char_indices() { + if count == n { + return StringCharLen::ByteOffset(byte_idx); + } + count += 1; + } + StringCharLen::CharCount(count) +} + +/// Result of [`char_count_or_boundary`]. +pub(crate) enum StringCharLen { + /// The string has more than `n` codepoints; contains the byte offset + /// at the `n`-th codepoint boundary. + ByteOffset(usize), + /// The string has `n` or fewer codepoints; contains the exact count. 
+ CharCount(usize), +} + /// Calculate the byte length of the substring of `n` chars from string `string` #[inline] fn left_right_byte_length(string: &str, n: i64) -> usize { @@ -88,11 +121,9 @@ fn left_right_byte_length(string: &str, n: i64) -> usize { .map(|(index, _)| index) .unwrap_or(0), Ordering::Equal => 0, - Ordering::Greater => string - .char_indices() - .nth(n.unsigned_abs().min(usize::MAX as u64) as usize) - .map(|(index, _)| index) - .unwrap_or(string.len()), + Ordering::Greater => { + byte_offset_of_char(string, n.unsigned_abs().min(usize::MAX as u64) as usize) + } } } diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index d7487c385e84f..d27bc8633e730 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -24,7 +24,6 @@ use arrow::array::{ OffsetSizeTrait, StringArrayType, StringViewArray, }; use arrow::datatypes::DataType; -use unicode_segmentation::UnicodeSegmentation; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::cast::as_int64_array; @@ -178,7 +177,9 @@ impl ScalarUDFImpl for LPadFunc { } } -use super::common::{try_as_scalar_i64, try_as_scalar_str}; +use super::common::{ + StringCharLen, char_count_or_boundary, try_as_scalar_i64, try_as_scalar_str, +}; /// Optimized lpad for constant target_len and fill arguments. 
fn lpad_scalar_args<'a, V: StringArrayType<'a> + Copy, T: OffsetSizeTrait>( @@ -270,27 +271,22 @@ fn lpad_scalar_unicode<'a, V: StringArrayType<'a> + Copy, T: OffsetSizeTrait>( let data_capacity = string_array.len().saturating_mul(target_len * 4); let mut builder = GenericStringBuilder::::with_capacity(string_array.len(), data_capacity); - let mut graphemes_buf = Vec::new(); for maybe_string in string_array.iter() { match maybe_string { - Some(string) => { - graphemes_buf.clear(); - graphemes_buf.extend(string.graphemes(true)); - - if target_len < graphemes_buf.len() { - let end: usize = - graphemes_buf[..target_len].iter().map(|g| g.len()).sum(); - builder.append_value(&string[..end]); - } else if fill_chars.is_empty() { - builder.append_value(string); - } else { - let pad_chars = target_len - graphemes_buf.len(); - let pad_bytes = char_byte_offsets[pad_chars]; - builder.write_str(&padding_buf[..pad_bytes])?; + Some(string) => match char_count_or_boundary(string, target_len) { + StringCharLen::ByteOffset(offset) => { + builder.append_value(&string[..offset]); + } + StringCharLen::CharCount(char_count) => { + if !fill_chars.is_empty() { + let pad_chars = target_len - char_count; + let pad_bytes = char_byte_offsets[pad_chars]; + builder.write_str(&padding_buf[..pad_bytes])?; + } builder.append_value(string); } - } + }, None => builder.append_null(), } } @@ -378,7 +374,6 @@ where { let array = if let Some(fill_array) = fill_array { let mut builder: GenericStringBuilder = GenericStringBuilder::new(); - let mut graphemes_buf = Vec::new(); let mut fill_chars_buf = Vec::new(); for ((string, target_len), fill) in string_array @@ -407,8 +402,7 @@ where } if string.is_ascii() && fill.is_ascii() { - // ASCII fast path: byte length == character length, - // so we skip expensive grapheme segmentation. + // ASCII fast path: byte length == character length. 
let str_len = string.len(); if target_len < str_len { builder.append_value(&string[..target_len]); @@ -428,26 +422,24 @@ where builder.append_value(string); } } else { - // Reuse buffers by clearing and refilling - graphemes_buf.clear(); - graphemes_buf.extend(string.graphemes(true)); - fill_chars_buf.clear(); fill_chars_buf.extend(fill.chars()); - if target_len < graphemes_buf.len() { - let end: usize = - graphemes_buf[..target_len].iter().map(|g| g.len()).sum(); - builder.append_value(&string[..end]); - } else if fill_chars_buf.is_empty() { - builder.append_value(string); - } else { - for l in 0..target_len - graphemes_buf.len() { - let c = - *fill_chars_buf.get(l % fill_chars_buf.len()).unwrap(); - builder.write_char(c)?; + match char_count_or_boundary(string, target_len) { + StringCharLen::ByteOffset(offset) => { + builder.append_value(&string[..offset]); + } + StringCharLen::CharCount(char_count) => { + if !fill_chars_buf.is_empty() { + for l in 0..target_len - char_count { + let c = *fill_chars_buf + .get(l % fill_chars_buf.len()) + .unwrap(); + builder.write_char(c)?; + } + } + builder.append_value(string); } - builder.append_value(string); } } } else { @@ -458,7 +450,6 @@ where builder.finish() } else { let mut builder: GenericStringBuilder = GenericStringBuilder::new(); - let mut graphemes_buf = Vec::new(); for (string, target_len) in string_array.iter().zip(length_array.iter()) { if let (Some(string), Some(target_len)) = (string, target_len) { @@ -491,19 +482,16 @@ where builder.append_value(string); } } else { - // Reuse buffer by clearing and refilling - graphemes_buf.clear(); - graphemes_buf.extend(string.graphemes(true)); - - if target_len < graphemes_buf.len() { - let end: usize = - graphemes_buf[..target_len].iter().map(|g| g.len()).sum(); - builder.append_value(&string[..end]); - } else { - for _ in 0..(target_len - graphemes_buf.len()) { - builder.write_str(" ")?; + match char_count_or_boundary(string, target_len) { + 
StringCharLen::ByteOffset(offset) => { + builder.append_value(&string[..offset]); + } + StringCharLen::CharCount(char_count) => { + for _ in 0..(target_len - char_count) { + builder.write_str(" ")?; + } + builder.append_value(string); } - builder.append_value(string); } } } else { diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs index 44ce4640422d6..b3e14f93526ab 100644 --- a/datafusion/functions/src/unicode/rpad.rs +++ b/datafusion/functions/src/unicode/rpad.rs @@ -24,7 +24,6 @@ use arrow::array::{ OffsetSizeTrait, StringArrayType, StringViewArray, }; use arrow::datatypes::DataType; -use unicode_segmentation::UnicodeSegmentation; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::cast::as_int64_array; @@ -178,7 +177,9 @@ impl ScalarUDFImpl for RPadFunc { } } -use super::common::{try_as_scalar_i64, try_as_scalar_str}; +use super::common::{ + StringCharLen, char_count_or_boundary, try_as_scalar_i64, try_as_scalar_str, +}; /// Optimized rpad for constant target_len and fill arguments. 
fn rpad_scalar_args<'a, V: StringArrayType<'a> + Copy, T: OffsetSizeTrait>( @@ -271,28 +272,23 @@ fn rpad_scalar_unicode<'a, V: StringArrayType<'a> + Copy, T: OffsetSizeTrait>( let data_capacity = string_array.len().saturating_mul(target_len * 4); let mut builder = GenericStringBuilder::::with_capacity(string_array.len(), data_capacity); - let mut graphemes_buf = Vec::new(); for maybe_string in string_array.iter() { match maybe_string { - Some(string) => { - graphemes_buf.clear(); - graphemes_buf.extend(string.graphemes(true)); - - if target_len < graphemes_buf.len() { - let end: usize = - graphemes_buf[..target_len].iter().map(|g| g.len()).sum(); - builder.append_value(&string[..end]); - } else if fill_chars.is_empty() { - builder.append_value(string); - } else { - let pad_chars = target_len - graphemes_buf.len(); - let pad_bytes = char_byte_offsets[pad_chars]; + Some(string) => match char_count_or_boundary(string, target_len) { + StringCharLen::ByteOffset(offset) => { + builder.append_value(&string[..offset]); + } + StringCharLen::CharCount(char_count) => { builder.write_str(string)?; - builder.write_str(&padding_buf[..pad_bytes])?; + if !fill_chars.is_empty() { + let pad_chars = target_len - char_count; + let pad_bytes = char_byte_offsets[pad_chars]; + builder.write_str(&padding_buf[..pad_bytes])?; + } builder.append_value(""); } - } + }, None => builder.append_null(), } } @@ -377,7 +373,6 @@ where { let array = if let Some(fill_array) = fill_array { let mut builder: GenericStringBuilder = GenericStringBuilder::new(); - let mut graphemes_buf = Vec::new(); let mut fill_chars_buf = Vec::new(); for ((string, target_len), fill) in string_array @@ -406,8 +401,7 @@ where } if string.is_ascii() && fill.is_ascii() { - // ASCII fast path: byte length == character length, - // so we skip expensive grapheme segmentation. + // ASCII fast path: byte length == character length. 
let str_len = string.len(); if target_len < str_len { builder.append_value(&string[..target_len]); @@ -428,26 +422,25 @@ where builder.append_value(""); } } else { - graphemes_buf.clear(); - graphemes_buf.extend(string.graphemes(true)); - fill_chars_buf.clear(); fill_chars_buf.extend(fill.chars()); - if target_len < graphemes_buf.len() { - let end: usize = - graphemes_buf[..target_len].iter().map(|g| g.len()).sum(); - builder.append_value(&string[..end]); - } else if fill_chars_buf.is_empty() { - builder.append_value(string); - } else { - builder.write_str(string)?; - for l in 0..target_len - graphemes_buf.len() { - let c = - *fill_chars_buf.get(l % fill_chars_buf.len()).unwrap(); - builder.write_char(c)?; + match char_count_or_boundary(string, target_len) { + StringCharLen::ByteOffset(offset) => { + builder.append_value(&string[..offset]); + } + StringCharLen::CharCount(char_count) => { + builder.write_str(string)?; + if !fill_chars_buf.is_empty() { + for l in 0..target_len - char_count { + let c = *fill_chars_buf + .get(l % fill_chars_buf.len()) + .unwrap(); + builder.write_char(c)?; + } + } + builder.append_value(""); } - builder.append_value(""); } } } else { @@ -458,7 +451,6 @@ where builder.finish() } else { let mut builder: GenericStringBuilder = GenericStringBuilder::new(); - let mut graphemes_buf = Vec::new(); for (string, target_len) in string_array.iter().zip(length_array.iter()) { if let (Some(string), Some(target_len)) = (string, target_len) { @@ -492,19 +484,17 @@ where builder.append_value(""); } } else { - graphemes_buf.clear(); - graphemes_buf.extend(string.graphemes(true)); - - if target_len < graphemes_buf.len() { - let end: usize = - graphemes_buf[..target_len].iter().map(|g| g.len()).sum(); - builder.append_value(&string[..end]); - } else { - builder.write_str(string)?; - for _ in 0..(target_len - graphemes_buf.len()) { - builder.write_str(" ")?; + match char_count_or_boundary(string, target_len) { + StringCharLen::ByteOffset(offset) => { + 
builder.append_value(&string[..offset]); + } + StringCharLen::CharCount(char_count) => { + builder.write_str(string)?; + for _ in 0..(target_len - char_count) { + builder.write_str(" ")?; + } + builder.append_value(""); } - builder.append_value(""); } } } else { diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs index 5f95c095a644e..29dc660b86f62 100644 --- a/datafusion/functions/src/unicode/translate.rs +++ b/datafusion/functions/src/unicode/translate.rs @@ -21,7 +21,6 @@ use arrow::array::{ }; use arrow::datatypes::DataType; use datafusion_common::HashMap; -use unicode_segmentation::UnicodeSegmentation; use crate::utils::make_scalar_function; use datafusion_common::{Result, exec_err}; @@ -97,11 +96,10 @@ impl ScalarUDFImpl for TranslateFunc { try_as_scalar_str(&args.args[1]), try_as_scalar_str(&args.args[2]), ) { - let to_graphemes: Vec<&str> = to_str.graphemes(true).collect(); + let to_chars: Vec = to_str.chars().collect(); - let mut from_map: HashMap<&str, usize> = HashMap::new(); - for (index, c) in from_str.graphemes(true).enumerate() { - // Ignore characters that already exist in from_map + let mut from_map: HashMap = HashMap::new(); + for (index, c) in from_str.chars().enumerate() { from_map.entry(c).or_insert(index); } @@ -117,7 +115,7 @@ impl ScalarUDFImpl for TranslateFunc { translate_with_map( arr, &from_map, - &to_graphemes, + &to_chars, ascii_table.as_ref(), builder, ) @@ -129,7 +127,7 @@ impl ScalarUDFImpl for TranslateFunc { translate_with_map( arr, &from_map, - &to_graphemes, + &to_chars, ascii_table.as_ref(), builder, ) @@ -141,7 +139,7 @@ impl ScalarUDFImpl for TranslateFunc { translate_with_map( arr, &from_map, - &to_graphemes, + &to_chars, ascii_table.as_ref(), builder, ) @@ -215,48 +213,27 @@ where let from_array_iter = ArrayIter::new(from_array); let to_array_iter = ArrayIter::new(to_array); - // Reusable buffers to avoid allocating for each row - let mut from_map: HashMap<&str, usize> = 
HashMap::new(); - let mut from_graphemes: Vec<&str> = Vec::new(); - let mut to_graphemes: Vec<&str> = Vec::new(); - let mut string_graphemes: Vec<&str> = Vec::new(); - let mut result_graphemes: Vec<&str> = Vec::new(); + let mut from_map: HashMap = HashMap::new(); + let mut to_chars: Vec = Vec::new(); + let mut result_buf = String::new(); for ((string, from), to) in string_array_iter.zip(from_array_iter).zip(to_array_iter) { match (string, from, to) { (Some(string), Some(from), Some(to)) => { - // Clear and reuse buffers from_map.clear(); - from_graphemes.clear(); - to_graphemes.clear(); - string_graphemes.clear(); - result_graphemes.clear(); - - // Build from_map using reusable buffer - from_graphemes.extend(from.graphemes(true)); - for (index, c) in from_graphemes.iter().enumerate() { - // Ignore characters that already exist in from_map - from_map.entry(*c).or_insert(index); - } + to_chars.clear(); + result_buf.clear(); - // Build to_graphemes - to_graphemes.extend(to.graphemes(true)); - - // Process string and build result - string_graphemes.extend(string.graphemes(true)); - for c in &string_graphemes { - match from_map.get(*c) { - Some(n) => { - if let Some(replacement) = to_graphemes.get(*n) { - result_graphemes.push(*replacement); - } - } - None => result_graphemes.push(*c), - } + for (index, c) in from.chars().enumerate() { + from_map.entry(c).or_insert(index); } - builder.append_value(&result_graphemes.concat()); + to_chars.extend(to.chars()); + + translate_char_by_char(string, &from_map, &to_chars, &mut result_buf); + + builder.append_value(&result_buf); } _ => builder.append_null(), } @@ -265,6 +242,27 @@ where Ok(builder.finish()) } +/// Translate `input` character-by-character using `from_map` and `to_chars`, +/// appending the result to `buf`. 
+#[inline] +fn translate_char_by_char( + input: &str, + from_map: &HashMap, + to_chars: &[char], + buf: &mut String, +) { + for c in input.chars() { + match from_map.get(&c) { + Some(n) => { + if let Some(&replacement) = to_chars.get(*n) { + buf.push(replacement); + } + } + None => buf.push(c), + } + } +} + /// Sentinel value in the ASCII translate table indicating the character should /// be deleted (the `from` character has no corresponding `to` character). Any /// value > 127 works since valid ASCII is 0–127. @@ -301,11 +299,11 @@ fn build_ascii_translate_table(from: &str, to: &str) -> Option<[u8; 128]> { /// Optimized translate for constant `from` and `to` arguments: uses a pre-built /// translation map instead of rebuilding it for every row. When an ASCII byte /// lookup table is provided, ASCII input rows use the lookup table; non-ASCII -/// inputs fallback to using the map. +/// inputs fall back to the char-based map. fn translate_with_map<'a, V, O>( string_array: V, - from_map: &HashMap<&str, usize>, - to_graphemes: &[&str], + from_map: &HashMap, + to_chars: &[char], ascii_table: Option<&[u8; 128]>, mut builder: O, ) -> Result @@ -313,7 +311,7 @@ where V: ArrayAccessor, O: StringLikeArrayBuilder, { - let mut result_graphemes: Vec<&str> = Vec::new(); + let mut result_buf = String::new(); let mut ascii_buf: Vec = Vec::new(); for string in ArrayIter::new(string_array) { @@ -335,21 +333,9 @@ where std::str::from_utf8_unchecked(&ascii_buf) }); } else { - // Slow path: grapheme-based translation - result_graphemes.clear(); - - for c in s.graphemes(true) { - match from_map.get(c) { - Some(n) => { - if let Some(replacement) = to_graphemes.get(*n) { - result_graphemes.push(*replacement); - } - } - None => result_graphemes.push(c), - } - } - - builder.append_value(&result_graphemes.concat()); + result_buf.clear(); + translate_char_by_char(s, from_map, to_chars, &mut result_buf); + builder.append_value(&result_buf); } } None => builder.append_null(), @@ -445,7 +431,7 
@@ mod tests { StringArray ); // Non-ASCII input with ASCII scalar from/to: exercises the - // grapheme fallback within translate_with_map. + // char-based fallback within translate_with_map. test_function!( TranslateFunc::new(), vec![ diff --git a/datafusion/sqllogictest/test_files/string/string_literal.slt b/datafusion/sqllogictest/test_files/string/string_literal.slt index d4fe8ee178719..97f2a40c13fea 100644 --- a/datafusion/sqllogictest/test_files/string/string_literal.slt +++ b/datafusion/sqllogictest/test_files/string/string_literal.slt @@ -312,6 +312,35 @@ SELECT lpad(NULL, 5, 'xy') ---- NULL +# lpad counts Unicode codepoints, not grapheme clusters. +# chr(769) is U+0301 COMBINING ACUTE ACCENT — 'e' || chr(769) is 2 codepoints +# but renders as a single grapheme cluster. + +# Input with combining character: 'e' + combining accent + 'x' = 3 codepoints. +# Padding to 4 means 1 space prepended. +query BII +SELECT lpad('e' || chr(769) || 'x', 4) = ' ' || 'e' || chr(769) || 'x', + character_length('e' || chr(769) || 'x'), + character_length(lpad('e' || chr(769) || 'x', 4)) +---- +true 3 4 + +# Truncating input with combining character: 'e' + combining accent + 'x' + 'y' +# = 4 codepoints. Truncating to 3 keeps first 3 codepoints: 'e' + combining accent + 'x'. +query BI +SELECT lpad('e' || chr(769) || 'xy', 3) = 'e' || chr(769) || 'x', + character_length(lpad('e' || chr(769) || 'xy', 3)) +---- +true 3 + +# Fill string with combining character: fill is 'e' + combining accent = 2 codepoints. +# Padding 'x' (1 codepoint) to length 5 means 4 fill codepoints = 2 cycles of fill. +query BI +SELECT lpad('x', 5, 'e' || chr(769)) = 'e' || chr(769) || 'e' || chr(769) || 'x', + character_length(lpad('x', 5, 'e' || chr(769))) +---- +true 5 + query T SELECT regexp_replace('foobar', 'bar', 'xx', 'gi') ---- @@ -583,6 +612,34 @@ SELECT rpad(arrow_cast(NULL, 'Utf8View'), 5, 'xy') ---- NULL +# rpad counts Unicode codepoints, not grapheme clusters. 
+# chr(769) is U+0301 COMBINING ACUTE ACCENT. + +# Input with combining character: 'e' + combining accent + 'x' = 3 codepoints. +# Padding to 4 means 1 space appended. +query BII +SELECT rpad('e' || chr(769) || 'x', 4) = 'e' || chr(769) || 'x' || ' ', + character_length('e' || chr(769) || 'x'), + character_length(rpad('e' || chr(769) || 'x', 4)) +---- +true 3 4 + +# Truncating input with combining character: 'e' + combining accent + 'x' + 'y' +# = 4 codepoints. Truncating to 3 keeps first 3 codepoints: 'e' + combining accent + 'x'. +query BI +SELECT rpad('e' || chr(769) || 'xy', 3) = 'e' || chr(769) || 'x', + character_length(rpad('e' || chr(769) || 'xy', 3)) +---- +true 3 + +# Fill string with combining character: fill is 'e' + combining accent = 2 codepoints. +# Padding 'x' (1 codepoint) to length 5 means 4 fill codepoints = 2 cycles of fill. +query BI +SELECT rpad('x', 5, 'e' || chr(769)) = 'x' || 'e' || chr(769) || 'e' || chr(769), + character_length(rpad('x', 5, 'e' || chr(769))) +---- +true 5 + query I SELECT char_length('') ---- @@ -1829,3 +1886,27 @@ query T SELECT arrow_typeof(translate(arrow_cast('12345', 'Utf8View'), '143', 'ax')) ---- Utf8View + +# translate operates on Unicode codepoints, not grapheme clusters. +# chr(769) is U+0301 COMBINING ACUTE ACCENT. + +# Replacing a combining accent (a single codepoint) with another character. +# 'e' || chr(769) is 2 codepoints; translating chr(769) → 'X' replaces just the accent. +query B +SELECT translate('e' || chr(769), chr(769), 'X') = 'eX' +---- +true + +# Replacing the base character but not the combining accent. +query B +SELECT translate('e' || chr(769) || 'y', 'e', 'a') = 'a' || chr(769) || 'y' +---- +true + +# Deleting a combining accent (from longer than to). +# 'e' || chr(769) || 'x' with chr(769) in `from` but no corresponding `to` entry → deleted. 
+query BI +SELECT translate('e' || chr(769) || 'x', chr(769), '') = 'ex', + character_length(translate('e' || chr(769) || 'x', chr(769), '')) +---- +true 2 diff --git a/docs/source/library-user-guide/upgrading/54.0.0.md b/docs/source/library-user-guide/upgrading/54.0.0.md index 4e6178345bcce..c5d03ebf8878c 100644 --- a/docs/source/library-user-guide/upgrading/54.0.0.md +++ b/docs/source/library-user-guide/upgrading/54.0.0.md @@ -355,4 +355,16 @@ which results in a few changes: `timestamp-*` is UTC timezone-aware, while `local-timestamp-*` is timezone-naive. +### `lpad`, `rpad`, and `translate` now operate on Unicode codepoints instead of grapheme clusters + +Previously, `lpad`, `rpad`, and `translate` used Unicode grapheme cluster +segmentation to measure and manipulate strings. They now use Unicode codepoints, +which is consistent with the SQL standard and most other SQL implementations. It +also matches the behavior of other string-related functions in DataFusion. + +The difference is only observable for strings containing combining characters +(e.g., U+0301 COMBINING ACUTE ACCENT) or other multi-codepoint grapheme +clusters (e.g., ZWJ emoji sequences). For ASCII and most common Unicode text, +behavior is unchanged. + [#17861]: https://github.com/apache/datafusion/pull/17861 From 6a770aa49de6ef3746c72838d2a044998e9535e0 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 9 Apr 2026 08:26:33 -0500 Subject: [PATCH 3/4] feat: add cast_to_type UDF for type-based casting (#21322) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Which issue does this PR close? N/A — new feature ## Rationale for this change DuckDB provides a [`cast_to_type(expression, reference)`](https://duckdb.org/docs/current/sql/expressions/cast#cast_to_type-function) function that casts the first argument to the data type of the second argument. 
This is useful in macros and generic SQL where types need to be preserved or matched dynamically. This PR adds the equivalent function to DataFusion, along with a fallible `try_cast_to_type` variant. ## What changes are included in this PR? - New `cast_to_type` scalar UDF in `datafusion/functions/src/core/cast_to_type.rs` - Takes two arguments: the expression to cast, and a reference expression whose **type** (not value) determines the target cast type - Uses `return_field_from_args` to infer return type from the second argument's data type - `simplify()` rewrites to `Expr::Cast` (or no-op if types match), so there is zero runtime overhead - New `try_cast_to_type` scalar UDF in `datafusion/functions/src/core/try_cast_to_type.rs` - Same as `cast_to_type` but returns NULL on cast failure instead of erroring - `simplify()` rewrites to `Expr::TryCast` - Output is always nullable - Registration of both functions in `datafusion/functions/src/core/mod.rs` ## Are these changes tested? Yes. New sqllogictest file `cast_to_type.slt` covering both functions: - Basic casts (string→int, string→double, int→string, int→double) - NULL handling - Same-type no-op - CASE expression as first argument - Arithmetic expression as first argument - Nested calls - Subquery as second argument - Column references as second argument - Boolean and date casts - Error on invalid cast (`cast_to_type`) vs NULL on invalid cast (`try_cast_to_type`) - Cross-column type matching ## Are there any user-facing changes? 
Two new SQL functions: - `cast_to_type(expression, reference)` — casts expression to the type of reference - `try_cast_to_type(expression, reference)` — same, but returns NULL on failure 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: Martin Grigorov --- datafusion/functions/src/core/arrow_cast.rs | 21 +- .../functions/src/core/arrow_try_cast.rs | 16 +- datafusion/functions/src/core/cast_to_type.rs | 146 ++++++++ datafusion/functions/src/core/mod.rs | 14 + .../functions/src/core/try_cast_to_type.rs | 130 +++++++ .../sqllogictest/test_files/cast_to_type.slt | 347 ++++++++++++++++++ .../source/user-guide/sql/scalar_functions.md | 59 +++ 7 files changed, 711 insertions(+), 22 deletions(-) create mode 100644 datafusion/functions/src/core/cast_to_type.rs create mode 100644 datafusion/functions/src/core/try_cast_to_type.rs create mode 100644 datafusion/sqllogictest/test_files/cast_to_type.slt diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index b05296721655e..0b67883c17c87 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -154,23 +154,20 @@ impl ScalarUDFImpl for ArrowCastFunc { fn simplify( &self, - mut args: Vec, + args: Vec, info: &SimplifyContext, ) -> Result { // convert this into a real cast - let target_type = data_type_from_args(self.name(), &args)?; - // remove second (type) argument - args.pop().unwrap(); - let arg = args.pop().unwrap(); - - let source_type = info.get_data_type(&arg)?; + let [source_arg, type_arg] = take_function_args(self.name(), args)?; + let target_type = data_type_from_type_arg(self.name(), &type_arg)?; + let source_type = info.get_data_type(&source_arg)?; let new_expr = if source_type == target_type { // the argument's data type is already the correct type - arg + source_arg } else { // Use an actual cast to get the correct type 
Expr::Cast(datafusion_expr::Cast { - expr: Box::new(arg), + expr: Box::new(source_arg), field: target_type.into_nullable_field_ref(), }) }; @@ -183,10 +180,8 @@ impl ScalarUDFImpl for ArrowCastFunc { } } -/// Returns the requested type from the arguments -pub(crate) fn data_type_from_args(name: &str, args: &[Expr]) -> Result { - let [_, type_arg] = take_function_args(name, args)?; - +/// Returns the requested type from the type argument +pub(crate) fn data_type_from_type_arg(name: &str, type_arg: &Expr) -> Result { let Expr::Literal(ScalarValue::Utf8(Some(val)), _) = type_arg else { return exec_err!( "{name} requires its second argument to be a constant string, got {:?}", diff --git a/datafusion/functions/src/core/arrow_try_cast.rs b/datafusion/functions/src/core/arrow_try_cast.rs index 61a5291c05ed9..d27b29ba5736d 100644 --- a/datafusion/functions/src/core/arrow_try_cast.rs +++ b/datafusion/functions/src/core/arrow_try_cast.rs @@ -31,7 +31,7 @@ use datafusion_expr::{ }; use datafusion_macros::user_doc; -use super::arrow_cast::data_type_from_args; +use super::arrow_cast::data_type_from_type_arg; /// Like [`arrow_cast`](super::arrow_cast::ArrowCastFunc) but returns NULL on cast failure instead of erroring. 
/// @@ -127,20 +127,18 @@ impl ScalarUDFImpl for ArrowTryCastFunc { fn simplify( &self, - mut args: Vec, + args: Vec, info: &SimplifyContext, ) -> Result { - let target_type = data_type_from_args(self.name(), &args)?; - // remove second (type) argument - args.pop().unwrap(); - let arg = args.pop().unwrap(); + let [source_arg, type_arg] = take_function_args(self.name(), args)?; + let target_type = data_type_from_type_arg(self.name(), &type_arg)?; - let source_type = info.get_data_type(&arg)?; + let source_type = info.get_data_type(&source_arg)?; let new_expr = if source_type == target_type { - arg + source_arg } else { Expr::TryCast(datafusion_expr::TryCast { - expr: Box::new(arg), + expr: Box::new(source_arg), field: target_type.into_nullable_field_ref(), }) }; diff --git a/datafusion/functions/src/core/cast_to_type.rs b/datafusion/functions/src/core/cast_to_type.rs new file mode 100644 index 0000000000000..abc7d440e04ba --- /dev/null +++ b/datafusion/functions/src/core/cast_to_type.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
[`CastToTypeFunc`]: Implementation of the `cast_to_type` function + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::{Result, internal_err, utils::take_function_args}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; +use datafusion_macros::user_doc; + +/// Casts the first argument to the data type of the second argument. +/// +/// Only the type of the second argument is used; its value is ignored. +/// This is useful in macros or generic SQL where you need to preserve +/// or match types dynamically. +/// +/// For example: +/// ```sql +/// select cast_to_type('42', NULL::INTEGER); +/// ``` +#[user_doc( + doc_section(label = "Other Functions"), + description = "Casts the first argument to the data type of the second argument. Only the type of the second argument is used; its value is ignored.", + syntax_example = "cast_to_type(expression, reference)", + sql_example = r#"```sql +> select cast_to_type('42', NULL::INTEGER) as a; ++----+ +| a | ++----+ +| 42 | ++----+ + +> select cast_to_type(1 + 2, NULL::DOUBLE) as b; ++-----+ +| b | ++-----+ +| 3.0 | ++-----+ +```"#, + argument( + name = "expression", + description = "The expression to cast. It can be a constant, column, or function, and any combination of operators." + ), + argument( + name = "reference", + description = "Reference expression whose data type determines the target cast type. The value is ignored." 
+ ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct CastToTypeFunc { + signature: Signature, +} + +impl Default for CastToTypeFunc { + fn default() -> Self { + Self::new() + } +} + +impl CastToTypeFunc { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Any), + Coercion::new_exact(TypeSignatureClass::Any), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for CastToTypeFunc { + fn name(&self) -> &str { + "cast_to_type" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be called instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let [source_field, reference_field] = + take_function_args(self.name(), args.arg_fields)?; + let target_type = reference_field.data_type().clone(); + // Nullability is inherited only from the first argument (the value + // being cast). The second argument is used solely for its type, so + // its own nullability is irrelevant. The one exception is when the + // target type is Null – that type is inherently nullable. + let nullable = source_field.is_nullable() || target_type == DataType::Null; + Ok(Field::new(self.name(), target_type, nullable).into()) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("cast_to_type should have been simplified to cast") + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [source_arg, type_arg] = take_function_args(self.name(), args)?; + let target_type = info.get_data_type(&type_arg)?; + let source_type = info.get_data_type(&source_arg)?; + let new_expr = if source_type == target_type { + // the argument's data type is already the correct type + source_arg + } else { + let nullable = info.nullable(&source_arg)? 
|| target_type == DataType::Null; + // Use an actual cast to get the correct type + Expr::Cast(datafusion_expr::Cast { + expr: Box::new(source_arg), + field: Field::new("", target_type, nullable).into(), + }) + }; + Ok(ExprSimplifyResult::Simplified(new_expr)) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index e8737612a1dcf..d3c48573667c9 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -24,6 +24,7 @@ pub mod arrow_cast; pub mod arrow_metadata; pub mod arrow_try_cast; pub mod arrowtypeof; +pub mod cast_to_type; pub mod coalesce; pub mod expr_ext; pub mod getfield; @@ -37,6 +38,7 @@ pub mod nvl2; pub mod overlay; pub mod planner; pub mod r#struct; +pub mod try_cast_to_type; pub mod union_extract; pub mod union_tag; pub mod version; @@ -44,6 +46,8 @@ pub mod version; // create UDFs make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast); make_udf_function!(arrow_try_cast::ArrowTryCastFunc, arrow_try_cast); +make_udf_function!(cast_to_type::CastToTypeFunc, cast_to_type); +make_udf_function!(try_cast_to_type::TryCastToTypeFunc, try_cast_to_type); make_udf_function!(nullif::NullIfFunc, nullif); make_udf_function!(nvl::NVLFunc, nvl); make_udf_function!(nvl2::NVL2Func, nvl2); @@ -75,6 +79,14 @@ pub mod expr_fn { arrow_try_cast, "Casts a value to a specific Arrow data type, returning NULL if the cast fails", arg1 arg2 + ),( + cast_to_type, + "Casts the first argument to the data type of the second argument", + arg1 arg2 + ),( + try_cast_to_type, + "Casts the first argument to the data type of the second argument, returning NULL on failure", + arg1 arg2 ),( nvl, "Returns value2 if value1 is NULL; otherwise it returns value1", @@ -147,6 +159,8 @@ pub fn functions() -> Vec> { nullif(), arrow_cast(), arrow_try_cast(), + cast_to_type(), + try_cast_to_type(), arrow_metadata(), nvl(), nvl2(), diff --git 
a/datafusion/functions/src/core/try_cast_to_type.rs b/datafusion/functions/src/core/try_cast_to_type.rs new file mode 100644 index 0000000000000..4c5af4cc6d228 --- /dev/null +++ b/datafusion/functions/src/core/try_cast_to_type.rs @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`TryCastToTypeFunc`]: Implementation of the `try_cast_to_type` function + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::{ + Result, datatype::DataTypeExt, internal_err, utils::take_function_args, +}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; +use datafusion_macros::user_doc; + +/// Like [`cast_to_type`](super::cast_to_type::CastToTypeFunc) but returns NULL +/// on cast failure instead of erroring. +/// +/// This is implemented by simplifying `try_cast_to_type(expr, ref)` into +/// `Expr::TryCast` during optimization. +#[user_doc( + doc_section(label = "Other Functions"), + description = "Casts the first argument to the data type of the second argument, returning NULL if the cast fails. 
Only the type of the second argument is used; its value is ignored.", + syntax_example = "try_cast_to_type(expression, reference)", + sql_example = r#"```sql +> select try_cast_to_type('123', NULL::INTEGER) as a, + try_cast_to_type('not_a_number', NULL::INTEGER) as b; + ++-----+------+ +| a | b | ++-----+------+ +| 123 | NULL | ++-----+------+ +```"#, + argument( + name = "expression", + description = "The expression to cast. It can be a constant, column, or function, and any combination of operators." + ), + argument( + name = "reference", + description = "Reference expression whose data type determines the target cast type. The value is ignored." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct TryCastToTypeFunc { + signature: Signature, +} + +impl Default for TryCastToTypeFunc { + fn default() -> Self { + Self::new() + } +} + +impl TryCastToTypeFunc { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Any), + Coercion::new_exact(TypeSignatureClass::Any), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TryCastToTypeFunc { + fn name(&self) -> &str { + "try_cast_to_type" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be called instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + // TryCast can always return NULL (on cast failure), so always nullable + let [_, reference_field] = take_function_args(self.name(), args.arg_fields)?; + let target_type = reference_field.data_type().clone(); + Ok(Field::new(self.name(), target_type, true).into()) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("try_cast_to_type should have been simplified to try_cast") + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [source_arg, type_arg] = 
take_function_args(self.name(), args)?; + let target_type = info.get_data_type(&type_arg)?; + let source_type = info.get_data_type(&source_arg)?; + let new_expr = if source_type == target_type { + source_arg + } else { + Expr::TryCast(datafusion_expr::TryCast { + expr: Box::new(source_arg), + field: target_type.into_nullable_field_ref(), + }) + }; + Ok(ExprSimplifyResult::Simplified(new_expr)) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} diff --git a/datafusion/sqllogictest/test_files/cast_to_type.slt b/datafusion/sqllogictest/test_files/cast_to_type.slt new file mode 100644 index 0000000000000..128846c0f5157 --- /dev/null +++ b/datafusion/sqllogictest/test_files/cast_to_type.slt @@ -0,0 +1,347 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +####### +## Tests for cast_to_type function +####### + +# Basic string to integer cast +query I +SELECT cast_to_type('42', 1::INTEGER); +---- +42 + +# String to double cast +query R +SELECT cast_to_type('3.14', 1.0::DOUBLE); +---- +3.14 + +# Integer to string cast +query T +SELECT cast_to_type(42, 'a'::VARCHAR); +---- +42 + +# Integer to double cast +query R +SELECT cast_to_type(42, 0.0::DOUBLE); +---- +42 + +# Same-type is a no-op +query I +SELECT cast_to_type(42, 0::INTEGER); +---- +42 + +# Second argument is a typed NULL double +query R +SELECT cast_to_type('3.14', NULL::DOUBLE); +---- +3.14 + +# Second argument is a typed NULL integer +query I +SELECT cast_to_type(42, NULL::INTEGER); +---- +42 + +# Second argument is a typed NULL string +query T +SELECT cast_to_type('42', NULL::VARCHAR); +---- +42 + +# NULL first argument +query I +SELECT cast_to_type(NULL, 0::INTEGER); +---- +NULL + +# CASE expression as first argument +query I +SELECT cast_to_type(CASE WHEN true THEN '1' ELSE '2' END, NULL::INTEGER); +---- +1 + +# Arithmetic expression as first argument +query R +SELECT cast_to_type(1 + 2, NULL::DOUBLE); +---- +3 + +# Nested cast_to_type +query T +SELECT cast_to_type(cast_to_type('3.14', NULL::DOUBLE), NULL::VARCHAR); +---- +3.14 + +# Subquery as second argument +query I +SELECT cast_to_type('42', (SELECT NULL::INTEGER)); +---- +42 + +# Column reference as second argument +statement ok +CREATE TABLE t1 (int_col INTEGER, text_col VARCHAR, double_col DOUBLE); + +statement ok +INSERT INTO t1 VALUES (1, 'hello', 3.14), (2, 'world', 2.72); + +query I +SELECT cast_to_type('99', int_col) FROM t1 LIMIT 1; +---- +99 + +query T +SELECT cast_to_type(123, text_col) FROM t1 LIMIT 1; +---- +123 + +query R +SELECT cast_to_type('1.5', double_col) FROM t1 LIMIT 1; +---- +1.5 + +# Case statement as second argument +query I +SELECT cast_to_type('42', CASE WHEN random() < 2 THEN 1 ELSE 0 END); +---- +42 + +# Use with column values as first argument +query R +SELECT 
cast_to_type(int_col, 1.0::DOUBLE) FROM t1; +---- +1 +2 + +# Cast column to match another column's type +query T +SELECT cast_to_type(int_col, text_col) FROM t1; +---- +1 +2 + +# Boolean cast +query B +SELECT cast_to_type(1, NULL::BOOLEAN); +---- +true + +# String to date cast +query D +SELECT cast_to_type('2024-01-15', NULL::DATE); +---- +2024-01-15 + +# Error on invalid cast +statement error Cannot cast string 'not_a_number' to value of Int32 type +SELECT cast_to_type('not_a_number', NULL::INTEGER); + +# Error on invalid target type +statement error Unsupported SQL type INVALID +SELECT cast_to_type('42', NULL::INVALID); + +statement ok +DROP TABLE t1; + +####### +## Nullability tests for cast_to_type +####### + +statement ok +set datafusion.catalog.information_schema = true; + +# Non-nullable input -> non-nullable output +statement ok +CREATE VIEW v_cast_nonnull AS SELECT cast_to_type(42, NULL::INTEGER) as a; + +query TTT +SELECT column_name, data_type, is_nullable FROM information_schema.columns WHERE table_name = 'v_cast_nonnull'; +---- +a Int32 NO + +statement ok +DROP VIEW v_cast_nonnull; + +# Nullable input -> nullable output +statement ok +CREATE TABLE t_nullable (x INTEGER); + +statement ok +INSERT INTO t_nullable VALUES (1), (NULL); + +statement ok +CREATE VIEW v_cast_null AS SELECT cast_to_type(x, 1.0::DOUBLE) as a FROM t_nullable; + +query TTT +SELECT column_name, data_type, is_nullable FROM information_schema.columns WHERE table_name = 'v_cast_null'; +---- +a Float64 YES + +# If we cast to the null type itself the result is nullable even if the input is not +statement ok +CREATE VIEW v_cast_to_null AS SELECT cast_to_type(42, null) as a; + +query TTT +SELECT column_name, data_type, is_nullable FROM information_schema.columns WHERE table_name = 'v_cast_to_null'; +---- +a Null YES + +statement ok +DROP VIEW v_cast_null; + +statement ok +DROP TABLE t_nullable; + +####### +## Tests for try_cast_to_type function (fallible variant returning NULL) +####### + 
+# Basic string to integer cast +query I +SELECT try_cast_to_type('42', NULL::INTEGER); +---- +42 + +# Invalid cast returns NULL instead of error +query I +SELECT try_cast_to_type('not_a_number', NULL::INTEGER); +---- +NULL + +# String to double cast +query R +SELECT try_cast_to_type('3.14', NULL::DOUBLE); +---- +3.14 + +# Invalid double returns NULL +query R +SELECT try_cast_to_type('abc', NULL::DOUBLE); +---- +NULL + +# Integer to string cast (always succeeds) +query T +SELECT try_cast_to_type(42, NULL::VARCHAR); +---- +42 + +# Same-type is a no-op +query I +SELECT try_cast_to_type(42, 0::INTEGER); +---- +42 + +# NULL first argument +query I +SELECT try_cast_to_type(NULL, 0::INTEGER); +---- +NULL + +# CASE expression as first argument +query I +SELECT try_cast_to_type(CASE WHEN true THEN '1' ELSE '2' END, NULL::INTEGER); +---- +1 + +# Arithmetic expression as first argument +query R +SELECT try_cast_to_type(1 + 2, NULL::DOUBLE); +---- +3 + +# Nested: try_cast_to_type inside cast_to_type +query T +SELECT cast_to_type(try_cast_to_type('3.14', NULL::DOUBLE), NULL::VARCHAR); +---- +3.14 + +# Subquery as second argument +query I +SELECT try_cast_to_type('42', (SELECT NULL::INTEGER)); +---- +42 + +# Column reference as second argument +statement ok +CREATE TABLE t2 (int_col INTEGER, text_col VARCHAR); + +statement ok +INSERT INTO t2 VALUES (1, 'hello'), (2, 'world'); + +query I +SELECT try_cast_to_type('99', int_col) FROM t2 LIMIT 1; +---- +99 + +query I +SELECT try_cast_to_type(text_col, int_col) FROM t2; +---- +NULL +NULL + +# Cast column to match another column's type +query T +SELECT try_cast_to_type(int_col, text_col) FROM t2; +---- +1 +2 + +# Boolean cast +query B +SELECT try_cast_to_type(1, NULL::BOOLEAN); +---- +true + +# String to date - valid +query D +SELECT try_cast_to_type('2024-01-15', NULL::DATE); +---- +2024-01-15 + +# String to date - invalid returns NULL +query D +SELECT try_cast_to_type('not_a_date', NULL::DATE); +---- +NULL + +statement ok +DROP 
TABLE t2; + +####### +## Nullability tests for try_cast_to_type +####### + +# try_cast_to_type is always nullable (cast can fail) +statement ok +CREATE VIEW v_trycast AS SELECT try_cast_to_type(42, 1::INTEGER) as a; + +query TTT +SELECT column_name, data_type, is_nullable FROM information_schema.columns WHERE table_name = 'v_trycast'; +---- +a Int32 YES + +statement ok +DROP VIEW v_trycast; + +statement ok +set datafusion.catalog.information_schema = false; diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index c303b43fc8844..d1b80f1f90b8b 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -5285,7 +5285,9 @@ union_tag(union_expression) - [arrow_metadata](#arrow_metadata) - [arrow_try_cast](#arrow_try_cast) - [arrow_typeof](#arrow_typeof) +- [cast_to_type](#cast_to_type) - [get_field](#get_field) +- [try_cast_to_type](#try_cast_to_type) - [version](#version) ### `arrow_cast` @@ -5405,6 +5407,37 @@ arrow_typeof(expression) +---------------------------+------------------------+ ``` +### `cast_to_type` + +Casts the first argument to the data type of the second argument. Only the type of the second argument is used; its value is ignored. + +```sql +cast_to_type(expression, reference) +``` + +#### Arguments + +- **expression**: The expression to cast. It can be a constant, column, or function, and any combination of operators. +- **reference**: Reference expression whose data type determines the target cast type. The value is ignored. + +#### Example + +```sql +> select cast_to_type('42', NULL::INTEGER) as a; ++----+ +| a | ++----+ +| 42 | ++----+ + +> select cast_to_type(1 + 2, NULL::DOUBLE) as b; ++-----+ +| b | ++-----+ +| 3.0 | ++-----+ +``` + ### `get_field` Returns a field within a map or a struct with the given key. 
@@ -5457,6 +5490,32 @@ get_field(expression, field_name[, field_name2, ...])
 +--------+
 ```
 
+### `try_cast_to_type`
+
+Casts the first argument to the data type of the second argument, returning NULL if the cast fails. Only the type of the second argument is used; its value is ignored.
+
+```sql
+try_cast_to_type(expression, reference)
+```
+
+#### Arguments
+
+- **expression**: The expression to cast. It can be a constant, column, or function, and any combination of operators.
+- **reference**: Reference expression whose data type determines the target cast type. The value is ignored.
+
+#### Example
+
+```sql
+> select try_cast_to_type('123', NULL::INTEGER) as a,
+         try_cast_to_type('not_a_number', NULL::INTEGER) as b;
+
++-----+------+
+| a | b |
++-----+------+
+| 123 | NULL |
++-----+------+
+```
+
 ### `version`
 
 Returns the version of DataFusion.

From 249c23c2b630fc1e43fc0de5e6ccc7d4c308a00b Mon Sep 17 00:00:00 2001
From: Andrew Lamb 
Date: Thu, 9 Apr 2026 12:20:28 -0400
Subject: [PATCH 4/4] Introduce Morselizer API, rewrite `ParquetOpener` to
 `ParquetMorselizer` (#21327)

~(Draft until I am sure I can use this API to make FileStream behave better)~

## Which issue does this PR close?

- part of https://github.com/apache/datafusion/issues/20529
- Needed for https://github.com/apache/datafusion/pull/21351
- Broken out of https://github.com/apache/datafusion/pull/20820
- Closes https://github.com/apache/datafusion/pull/21427

## Rationale for this change

I can get 10% faster on many ClickBench queries by reordering files at runtime. You can see it all working together here: https://github.com/apache/datafusion/pull/21351

To do so, I need to rework the FileStream so that it can reorder operations at runtime. Eventually that will include both CPU and IO.

This PR is a step in that direction by introducing the main Morsel API and implementing it for Parquet. 
The next PR (https://github.com/apache/datafusion/pull/21342) rewrites FileStream in terms of the Morsel API ## What changes are included in this PR? 1. Add proposed `Morsel` API 2. Rewrite Parquet opener in terms of that API 3. Add an adapter layer (back to FileOpener, so I don't have to rewrite FileStream in the same PR) My next PR will rewrite the FileStream to use the Morsel API ## Are these changes tested? Yes by existing CI. I will work on adding additional tests for just Parquet opener in a follow on PR ## Are there any user-facing changes? No --- datafusion/datasource-parquet/src/opener.rs | 357 ++++++++++++++------ datafusion/datasource-parquet/src/source.rs | 56 +-- datafusion/datasource/src/mod.rs | 1 + datafusion/datasource/src/morsel/mod.rs | 229 +++++++++++++ 4 files changed, 512 insertions(+), 131 deletions(-) create mode 100644 datafusion/datasource/src/morsel/mod.rs diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index 6621706c35c81..35900e16c18ed 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`ParquetOpener`] state machine for opening Parquet files +//! 
[`ParquetOpener`] and [`ParquetMorselizer`] state machines for opening Parquet files use crate::page_filter::PagePruningAccessPlanFilter; use crate::row_filter::build_projection_read_plan; @@ -26,11 +26,16 @@ use crate::{ }; use arrow::array::{RecordBatch, RecordBatchOptions}; use arrow::datatypes::DataType; +use datafusion_common::internal_err; use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener}; +use datafusion_datasource::morsel::{ + Morsel, MorselPlan, MorselPlanner, Morselizer, PendingMorselPlanner, +}; use datafusion_physical_expr::projection::{ProjectionExprs, Projector}; use datafusion_physical_expr::utils::reassign_expr_columns; use datafusion_physical_expr_adapter::replace_columns_with_literals; -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; +use std::fmt; use std::future::Future; use std::mem; use std::pin::Pin; @@ -77,12 +82,26 @@ use parquet::bloom_filter::Sbbf; use parquet::errors::ParquetError; use parquet::file::metadata::{PageIndexPolicy, ParquetMetaDataReader}; -/// Entry point for opening a Parquet file +/// Implements [`FileOpener`] for Parquet +#[derive(Clone)] +pub(super) struct ParquetOpener { + pub(super) morselizer: ParquetMorselizer, +} + +impl FileOpener for ParquetOpener { + fn open(&self, partitioned_file: PartitionedFile) -> Result { + let future = ParquetOpenFuture::new(&self.morselizer, partitioned_file)?; + Ok(Box::pin(future)) + } +} + +/// Stateless Parquet morselizer implementation. /// /// Reading a Parquet file is a multi-stage process, with multiple CPU-intensive /// steps interspersed with I/O steps. The code in this module implements the steps /// as an explicit state machine -- see [`ParquetOpenState`] for details. -pub(super) struct ParquetOpener { +#[derive(Clone)] +pub(super) struct ParquetMorselizer { /// Execution partition index pub(crate) partition_index: usize, /// Projection to apply on top of the table schema (i.e. can reference partition columns). 
@@ -137,6 +156,23 @@ pub(super) struct ParquetOpener { pub reverse_row_groups: bool, } +impl fmt::Debug for ParquetMorselizer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ParquetMorselizer") + .field("partition_index", &self.partition_index) + .field("preserve_order", &self.preserve_order) + .field("enable_page_index", &self.enable_page_index) + .field("enable_bloom_filter", &self.enable_bloom_filter) + .finish() + } +} + +impl Morselizer for ParquetMorselizer { + fn plan_file(&self, file: PartitionedFile) -> Result> { + Ok(Box::new(ParquetMorselPlanner::try_new(self, file)?)) + } +} + /// States for [`ParquetOpenFuture`] /// /// These states correspond to the steps required to read and apply various @@ -216,6 +252,27 @@ enum ParquetOpenState { Done, } +impl fmt::Debug for ParquetOpenState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let state = match self { + ParquetOpenState::Start { .. } => "Start", + #[cfg(feature = "parquet_encryption")] + ParquetOpenState::LoadEncryption(_) => "LoadEncryption", + ParquetOpenState::PruneFile(_) => "PruneFile", + ParquetOpenState::LoadMetadata(_) => "LoadMetadata", + ParquetOpenState::PrepareFilters(_) => "PrepareFilters", + ParquetOpenState::LoadPageIndex(_) => "LoadPageIndex", + ParquetOpenState::PruneWithStatistics(_) => "PruneWithStatistics", + ParquetOpenState::LoadBloomFilters(_) => "LoadBloomFilters", + ParquetOpenState::PruneWithBloomFilters(_) => "PruneWithBloomFilters", + ParquetOpenState::BuildStream(_) => "BuildStream", + ParquetOpenState::Ready(_) => "Ready", + ParquetOpenState::Done => "Done", + }; + f.write_str(state) + } +} + struct PreparedParquetOpen { partition_index: usize, partitioned_file: PartitionedFile, @@ -290,37 +347,13 @@ struct BloomFiltersLoadedParquetOpen { row_group_bloom_filters: Vec, } -/// Implements state machine described in [`ParquetOpenState`] -struct ParquetOpenFuture { - state: ParquetOpenState, -} - -impl ParquetOpenFuture { - 
#[cfg(feature = "parquet_encryption")] - fn new(prepared: PreparedParquetOpen, encryption_context: EncryptionContext) -> Self { - Self { - state: ParquetOpenState::Start { - prepared: Box::new(prepared), - encryption_context: Arc::new(encryption_context), - }, - } - } - - #[cfg(not(feature = "parquet_encryption"))] - fn new(prepared: PreparedParquetOpen) -> Self { - Self { - state: ParquetOpenState::Start { - prepared: Box::new(prepared), - }, - } - } -} - impl ParquetOpenState { /// Applies one CPU-only state transition. /// /// `Load*` states do not transition here and are returned unchanged so the /// driver loop can poll their inner futures separately. + /// + /// Implements state machine described in [`ParquetOpenState`] fn transition(self) -> Result { match self { ParquetOpenState::Start { @@ -392,93 +425,208 @@ impl ParquetOpenState { } } +/// Adapter for a [`MorselPlanner`] to the [`FileOpener`] API +/// +/// Compatibility adapter that drives a morsel planner through the +/// [`FileOpener`] API. 
+struct ParquetOpenFuture { + planner: Option>, + pending_io: Option, + ready_morsels: VecDeque>, +} + +impl ParquetOpenFuture { + fn new( + morselizer: &ParquetMorselizer, + partitioned_file: PartitionedFile, + ) -> Result { + Ok(Self { + planner: Some(morselizer.plan_file(partitioned_file)?), + pending_io: None, + ready_morsels: VecDeque::new(), + }) + } +} + impl Future for ParquetOpenFuture { type Output = Result>>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { - let state = mem::replace(&mut self.state, ParquetOpenState::Done); - let mut state = state.transition()?; - - match state { - #[cfg(feature = "parquet_encryption")] - ParquetOpenState::LoadEncryption(mut future) => { - state = match future.poll_unpin(cx) { - Poll::Ready(result) => ParquetOpenState::PruneFile(result?), - Poll::Pending => { - self.state = ParquetOpenState::LoadEncryption(future); - return Poll::Pending; - } - }; - } - ParquetOpenState::LoadMetadata(mut future) => { - state = match future.poll_unpin(cx) { - Poll::Ready(result) => { - ParquetOpenState::PrepareFilters(Box::new(result?)) - } - Poll::Pending => { - self.state = ParquetOpenState::LoadMetadata(future); - return Poll::Pending; - } - }; - } - ParquetOpenState::LoadPageIndex(mut future) => { - state = match future.poll_unpin(cx) { - Poll::Ready(result) => { - ParquetOpenState::PruneWithStatistics(Box::new(result?)) - } - Poll::Pending => { - self.state = ParquetOpenState::LoadPageIndex(future); - return Poll::Pending; - } - }; - } - ParquetOpenState::LoadBloomFilters(mut future) => { - state = match future.poll_unpin(cx) { - Poll::Ready(result) => { - ParquetOpenState::PruneWithBloomFilters(Box::new(result?)) - } - Poll::Pending => { - self.state = ParquetOpenState::LoadBloomFilters(future); - return Poll::Pending; - } - }; - } - ParquetOpenState::Ready(stream) => { - return Poll::Ready(Ok(stream)); - } - ParquetOpenState::Done => { - return Poll::Ready(Ok(futures::stream::empty().boxed())); + // If 
planner I/O completed, resume with the returned planner. + if let Some(io_future) = self.pending_io.as_mut() { + let maybe_planner = ready!(io_future.poll_unpin(cx)); + // Clear `pending_io` before handling the result so an error + // cannot leave both continuation paths populated. + self.pending_io = None; + if self.planner.is_some() { + return Poll::Ready(internal_err!( + "ParquetOpenFuture does not support concurrent planners" + )); } + self.planner = Some(maybe_planner?); + } + + // If a stream morsel is ready, return it. + if let Some(morsel) = self.ready_morsels.pop_front() { + return Poll::Ready(Ok(morsel.into_stream())); + } - // For all other states, loop again and try to transition - // immediately. All states are explicitly listed here to ensure any - // new states are handled correctly - ParquetOpenState::Start { .. } => {} - ParquetOpenState::PruneFile(_) => {} - ParquetOpenState::PrepareFilters(_) => {} - ParquetOpenState::PruneWithStatistics(_) => {} - ParquetOpenState::PruneWithBloomFilters(_) => {} - ParquetOpenState::BuildStream(_) => {} + // This shim must always own either a planner, a pending planner + // future, or a ready morsel. Reaching this branch means the + // continuation was lost. + let Some(planner) = self.planner.take() else { + return Poll::Ready(internal_err!( + "ParquetOpenFuture polled after completion" + )); }; - self.state = state; + // Planner completed without producing a stream morsel. + // (e.g. all row groups were pruned) + let Some(mut plan) = planner.plan()? 
else { + return Poll::Ready(Ok(futures::stream::empty().boxed())); + }; + + let mut child_planners = plan.take_ready_planners(); + if child_planners.len() > 1 { + return Poll::Ready(internal_err!( + "Parquet FileOpener adapter does not support child morsel planners" + )); + } + self.planner = child_planners.pop(); + + self.ready_morsels = plan.take_morsels().into(); + + if let Some(io_future) = plan.take_pending_planner() { + self.pending_io = Some(io_future); + } } } } -impl FileOpener for ParquetOpener { - fn open(&self, partitioned_file: PartitionedFile) -> Result { - let prepared = self.prepare_open_file(partitioned_file)?; +/// Implements the Morsel API +struct ParquetStreamMorsel { + stream: BoxStream<'static, Result>, +} + +impl ParquetStreamMorsel { + fn new(stream: BoxStream<'static, Result>) -> Self { + Self { stream } + } +} + +impl fmt::Debug for ParquetStreamMorsel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ParquetStreamMorsel") + .finish_non_exhaustive() + } +} + +impl Morsel for ParquetStreamMorsel { + fn into_stream(self: Box) -> BoxStream<'static, Result> { + self.stream + } +} + +/// Per-file planner that owns the current [`ParquetOpenState`]. +struct ParquetMorselPlanner { + /// Ready to perform CPU-only planning work. 
+ state: ParquetOpenState, +} + +impl fmt::Debug for ParquetMorselPlanner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("ParquetMorselPlanner::Ready") + .field(&self.state) + .finish() + } +} + +impl ParquetMorselPlanner { + fn try_new(morselizer: &ParquetMorselizer, file: PartitionedFile) -> Result { + let prepared = morselizer.prepare_open_file(file)?; #[cfg(feature = "parquet_encryption")] - let future = ParquetOpenFuture::new(prepared, self.get_encryption_context()); + let state = ParquetOpenState::Start { + prepared: Box::new(prepared), + encryption_context: Arc::new(morselizer.get_encryption_context()), + }; #[cfg(not(feature = "parquet_encryption"))] - let future = ParquetOpenFuture::new(prepared); - Ok(Box::pin(future)) + let state = ParquetOpenState::Start { + prepared: Box::new(prepared), + }; + Ok(Self { state }) + } + + /// Schedule an I/O future that resolves to the next planner to run. + /// + /// This helper + /// + /// 1. drives one I/O phase to completion + /// 2. wraps the resulting state in a new [`ParquetMorselPlanner`] + /// 3. 
returns a [`MorselPlan`] containing the boxed future for the caller + /// to poll + /// + fn schedule_io(future: F) -> MorselPlan + where + F: Future> + Send + 'static, + { + let io_future = async move { + let next_state = future.await?; + Ok(Box::new(ParquetMorselPlanner { state: next_state }) as _) + }; + MorselPlan::new().with_pending_planner(io_future) + } +} + +impl MorselPlanner for ParquetMorselPlanner { + fn plan(self: Box) -> Result> { + if let ParquetOpenState::Done = self.state { + return Ok(None); + } + + let state = self.state.transition()?; + + match state { + #[cfg(feature = "parquet_encryption")] + ParquetOpenState::LoadEncryption(future) => { + Ok(Some(Self::schedule_io(async move { + Ok(ParquetOpenState::PruneFile(future.await?)) + }))) + } + ParquetOpenState::LoadMetadata(future) => { + Ok(Some(Self::schedule_io(async move { + Ok(ParquetOpenState::PrepareFilters(Box::new(future.await?))) + }))) + } + ParquetOpenState::LoadPageIndex(future) => { + Ok(Some(Self::schedule_io(async move { + Ok(ParquetOpenState::PruneWithStatistics(Box::new( + future.await?, + ))) + }))) + } + ParquetOpenState::LoadBloomFilters(future) => { + Ok(Some(Self::schedule_io(async move { + Ok(ParquetOpenState::PruneWithBloomFilters(Box::new( + future.await?, + ))) + }))) + } + ParquetOpenState::Ready(stream) => { + let morsels: Vec> = + vec![Box::new(ParquetStreamMorsel::new(stream))]; + Ok(Some(MorselPlan::new().with_morsels(morsels))) + } + ParquetOpenState::Done => Ok(None), + cpu_state => Ok(Some( + MorselPlan::new() + .with_planners(vec![Box::new(Self { state: cpu_state })]), + )), + } } } -impl ParquetOpener { +impl ParquetMorselizer { /// Perform the CPU-only setup for opening a parquet file. 
fn prepare_open_file( &self, @@ -1447,7 +1595,7 @@ impl EncryptionContext { } } -impl ParquetOpener { +impl ParquetMorselizer { #[cfg(feature = "parquet_encryption")] fn get_encryption_context(&self) -> EncryptionContext { EncryptionContext::new( @@ -1576,7 +1724,7 @@ fn should_enable_page_index( mod test { use std::sync::Arc; - use super::{ConstantColumns, constant_columns_from_stats}; + use super::{ConstantColumns, ParquetMorselizer, constant_columns_from_stats}; use crate::{DefaultParquetFileReaderFactory, RowGroupAccess, opener::ParquetOpener}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use bytes::{BufMut, BytesMut}; @@ -1731,11 +1879,12 @@ mod test { ProjectionExprs::from_indices(&all_indices, &file_schema) }; - ParquetOpener { + let morselizer = ParquetMorselizer { partition_index: self.partition_index, projection, batch_size: self.batch_size, limit: self.limit, + preserve_order: self.preserve_order, predicate: self.predicate, table_schema, metadata_size_hint: self.metadata_size_hint, @@ -1757,8 +1906,8 @@ mod test { encryption_factory: None, max_predicate_cache_size: self.max_predicate_cache_size, reverse_row_groups: self.reverse_row_groups, - preserve_order: self.preserve_order, - } + }; + ParquetOpener { morselizer } } } diff --git a/datafusion/datasource-parquet/src/source.rs b/datafusion/datasource-parquet/src/source.rs index 3a64137a2a3f8..1e54e98dfd04b 100644 --- a/datafusion/datasource-parquet/src/source.rs +++ b/datafusion/datasource-parquet/src/source.rs @@ -23,8 +23,8 @@ use std::sync::Arc; use crate::DefaultParquetFileReaderFactory; use crate::ParquetFileReaderFactory; -use crate::opener::ParquetOpener; use crate::opener::build_pruning_predicates; +use crate::opener::{ParquetMorselizer, ParquetOpener}; use crate::row_filter::can_expr_be_pushed_down_with_schemas; use datafusion_common::config::ConfigOptions; #[cfg(feature = "parquet_encryption")] @@ -543,32 +543,34 @@ impl FileSource for ParquetSource { .map(|time_unit| 
parse_coerce_int96_string(time_unit.as_str()).unwrap()); let opener = Arc::new(ParquetOpener { - partition_index: partition, - projection: self.projection.clone(), - batch_size: self - .batch_size - .expect("Batch size must set before creating ParquetOpener"), - limit: base_config.limit, - preserve_order: base_config.preserve_order, - predicate: self.predicate.clone(), - table_schema: self.table_schema.clone(), - metadata_size_hint: self.metadata_size_hint, - metrics: self.metrics().clone(), - parquet_file_reader_factory, - pushdown_filters: self.pushdown_filters(), - reorder_filters: self.reorder_filters(), - force_filter_selections: self.force_filter_selections(), - enable_page_index: self.enable_page_index(), - enable_bloom_filter: self.bloom_filter_on_read(), - enable_row_group_stats_pruning: self.table_parquet_options.global.pruning, - coerce_int96, - #[cfg(feature = "parquet_encryption")] - file_decryption_properties, - expr_adapter_factory, - #[cfg(feature = "parquet_encryption")] - encryption_factory: self.get_encryption_factory_with_config(), - max_predicate_cache_size: self.max_predicate_cache_size(), - reverse_row_groups: self.reverse_row_groups, + morselizer: ParquetMorselizer { + partition_index: partition, + projection: self.projection.clone(), + batch_size: self + .batch_size + .expect("Batch size must set before creating ParquetOpener"), + limit: base_config.limit, + preserve_order: base_config.preserve_order, + predicate: self.predicate.clone(), + table_schema: self.table_schema.clone(), + metadata_size_hint: self.metadata_size_hint, + metrics: self.metrics().clone(), + parquet_file_reader_factory, + pushdown_filters: self.pushdown_filters(), + reorder_filters: self.reorder_filters(), + force_filter_selections: self.force_filter_selections(), + enable_page_index: self.enable_page_index(), + enable_bloom_filter: self.bloom_filter_on_read(), + enable_row_group_stats_pruning: self.table_parquet_options.global.pruning, + coerce_int96, + #[cfg(feature = 
"parquet_encryption")] + file_decryption_properties, + expr_adapter_factory, + #[cfg(feature = "parquet_encryption")] + encryption_factory: self.get_encryption_factory_with_config(), + max_predicate_cache_size: self.max_predicate_cache_size(), + reverse_row_groups: self.reverse_row_groups, + }, }); Ok(opener) } diff --git a/datafusion/datasource/src/mod.rs b/datafusion/datasource/src/mod.rs index bcc4627050d4a..a9600271c28ce 100644 --- a/datafusion/datasource/src/mod.rs +++ b/datafusion/datasource/src/mod.rs @@ -38,6 +38,7 @@ pub mod file_scan_config; pub mod file_sink_config; pub mod file_stream; pub mod memory; +pub mod morsel; pub mod projection; pub mod schema_adapter; pub mod sink; diff --git a/datafusion/datasource/src/morsel/mod.rs b/datafusion/datasource/src/morsel/mod.rs new file mode 100644 index 0000000000000..5f200d7794690 --- /dev/null +++ b/datafusion/datasource/src/morsel/mod.rs @@ -0,0 +1,229 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Structures for Morsel Driven IO. +//! +//! NOTE: As of DataFusion 54.0.0, these are experimental APIs that may change +//! substantially. +//! +//! Morsel Driven IO is a technique for parallelizing the reading of large files +//! 
by dividing them into smaller "morsels" that are processed independently. +//! +//! It is inspired by the paper [Morsel-Driven Parallelism: A NUMA-Aware Query +//! Evaluation Framework for the Many-Core Age](https://db.in.tum.de/~leis/papers/morsels.pdf). + +use crate::PartitionedFile; +use arrow::array::RecordBatch; +use datafusion_common::Result; +use futures::FutureExt; +use futures::future::BoxFuture; +use futures::stream::BoxStream; +use std::fmt::Debug; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// A Morsel of work ready to resolve to a stream of [`RecordBatch`]es. +/// +/// This represents a single morsel of work that is ready to be processed. It +/// has all data necessary (does not need any I/O) and is ready to be turned +/// into a stream of [`RecordBatch`]es for processing by the execution engine. +pub trait Morsel: Send + Debug { + /// Consume this morsel and produce a stream of [`RecordBatch`]es for processing. + /// + /// Note: This may do CPU work to decode already-loaded data, but should not + /// do any I/O work such as reading from the file. + fn into_stream(self: Box) -> BoxStream<'static, Result>; +} + +/// A Morselizer takes a single [`PartitionedFile`] and creates the initial planner +/// for that file. +/// +/// This is the entry point for morsel driven I/O. +pub trait Morselizer: Send + Sync + Debug { + /// Return the initial [`MorselPlanner`] for this file. + /// + /// Morselizing a file may involve CPU work, such as parsing parquet + /// metadata and evaluating pruning predicates. It should NOT do any I/O + /// work, such as reading from the file. Any needed I/O should be done using + /// [`MorselPlan::with_pending_planner`]. + fn plan_file(&self, file: PartitionedFile) -> Result>; +} + +/// A Morsel Planner is responsible for creating morsels for a given scan. +/// +/// The [`MorselPlanner`] is the unit of I/O. There is only ever a single I/O +/// outstanding for a specific planner. 
DataFusion may run +/// multiple planners in parallel, which corresponds to multiple parallel +/// I/O requests. +/// +/// It is not a Rust `Stream` so that it can explicitly separate CPU bound +/// work from I/O work. +/// +/// The design is similar to `ParquetPushDecoder`: when `plan` is called, it +/// should do CPU work to produce the next morsels or discover the next I/O +/// phase. +/// +/// Best practice is to spawn I/O in a Tokio task on a separate runtime to +/// ensure that CPU work doesn't block or slow down I/O work, but this is not +/// strictly required by the API. +pub trait MorselPlanner: Send + Debug { + /// Attempt to plan morsels. This may involve CPU work, such as parsing + /// parquet metadata and evaluating pruning predicates. + /// + /// It should NOT do any I/O work, such as reading from the file. If I/O is + /// required, the returned [`MorselPlan`] should contain a pending planner + /// future that the caller polls to drive the I/O work to completion. Once + /// that future resolves, it yields a planner ready for work. + /// + /// Note this function is **not async** to make it explicitly clear that if + /// I/O is required, it should be done in the returned `io_future`. + /// + /// Returns `None` if the planner has no more work to do. + /// + /// # Empty Morsel Plans + /// + /// It may return `None`, which means no batches will be read from the file + /// (e.g. due to late-pruning based on statistics). + /// + /// # Output Ordering + /// + /// See the comments on [`MorselPlan`] for the logical output order. + fn plan(self: Box) -> Result>; +} + +/// Return result of [`MorselPlanner::plan`]. +/// +/// # Logical Ordering +/// +/// For plans where the output order of rows is maintained, the output order of +/// a [`MorselPlanner`] is logically defined as follows: +/// 1. All morsels that are directly produced +/// 2. 
Recursively, all morsels produced by the returned `planners` +#[derive(Default)] +pub struct MorselPlan { + /// Morsels ready for CPU work + morsels: Vec>, + /// Planners that are ready for CPU work. + ready_planners: Vec>, + /// A future with planner I/O that resolves to a CPU ready planner. + /// + /// DataFusion will poll this future occasionally to drive the I/O work to + /// completion. Once it resolves, planning continues with the returned + /// planner. + pending_planner: Option, +} + +impl MorselPlan { + /// Create an empty morsel plan. + pub fn new() -> Self { + Self::default() + } + + /// Set the ready morsels. + pub fn with_morsels(mut self, morsels: Vec>) -> Self { + self.morsels = morsels; + self + } + + /// Set the ready child planners. + pub fn with_planners(mut self, planners: Vec>) -> Self { + self.ready_planners = planners; + self + } + + /// Set the pending planner for an I/O phase. + pub fn with_pending_planner(mut self, io_future: F) -> Self + where + F: Future>> + Send + 'static, + { + self.pending_planner = Some(PendingMorselPlanner::new(io_future)); + self + } + + /// Set the pending planner for an I/O phase. + pub fn set_pending_planner(&mut self, io_future: F) + where + F: Future>> + Send + 'static, + { + self.pending_planner = Some(PendingMorselPlanner::new(io_future)); + } + + /// Take the ready morsels. + pub fn take_morsels(&mut self) -> Vec> { + std::mem::take(&mut self.morsels) + } + + /// Take the ready child planners. + pub fn take_ready_planners(&mut self) -> Vec> { + std::mem::take(&mut self.ready_planners) + } + + /// Take the pending I/O future, if any. + pub fn take_pending_planner(&mut self) -> Option { + self.pending_planner.take() + } + + /// Returns `true` if this plan contains an I/O future. + pub fn has_io_future(&self) -> bool { + self.pending_planner.is_some() + } +} + +/// Wrapper for I/O that must complete before planning can continue. 
+pub struct PendingMorselPlanner {
+    future: BoxFuture<'static, Result>>,
+}
+
+impl PendingMorselPlanner {
+    /// Create a new pending planner future.
+    ///
+    /// # Example
+    /// ```
+    /// # use datafusion_common::DataFusionError;
+    /// # use datafusion_datasource::morsel::{MorselPlanner, PendingMorselPlanner};
+    /// let work = async move {
+    ///     let planner: Box = {
+    ///         // Do I/O work here, then return the next planner to run.
+    ///         # unimplemented!();
+    ///     };
+    ///     Ok::<_, DataFusionError>(planner)
+    /// };
+    /// let pending_io = PendingMorselPlanner::new(work);
+    /// ```
+    pub fn new(future: F) -> Self
+    where
+        F: Future>> + Send + 'static,
+    {
+        Self {
+            future: future.boxed(),
+        }
+    }
+
+    /// Consume this wrapper and return the underlying future.
+    pub fn into_future(self) -> BoxFuture<'static, Result>> {
+        self.future
+    }
+}
+
+/// Forwards polling to the underlying future.
+impl Future for PendingMorselPlanner {
+    type Output = Result>;
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll {
+        // forward request to inner
+        self.future.as_mut().poll(cx)
+    }
+}