From 68f74aecde2044f3e98a28561fe7bc099340a120 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 29 Nov 2025 21:17:49 -0800 Subject: [PATCH 1/4] prototype --- native/spark-expr/src/comet_scalar_funcs.rs | 4 +- native/spark-expr/src/string_funcs/mod.rs | 2 + .../src/string_funcs/regexp_extract.rs | 401 ++++++++++++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 2 + .../org/apache/comet/serde/strings.scala | 100 ++++- .../comet/CometStringExpressionSuite.scala | 122 ++++++ 6 files changed, 629 insertions(+), 2 deletions(-) create mode 100644 native/spark-expr/src/string_funcs/regexp_extract.rs diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 021bb1c78f..45b4ca8aad 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -23,7 +23,7 @@ use crate::{ spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkBitwiseNot, - SparkDateTrunc, SparkStringSpace, + SparkDateTrunc, SparkRegExpExtract, SparkRegExpExtractAll, SparkStringSpace, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -199,6 +199,8 @@ fn all_scalar_functions() -> Vec> { Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())), Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())), Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())), + Arc::new(ScalarUDF::new_from_impl(SparkRegExpExtract::default())), + Arc::new(ScalarUDF::new_from_impl(SparkRegExpExtractAll::default())), ] } diff --git a/native/spark-expr/src/string_funcs/mod.rs b/native/spark-expr/src/string_funcs/mod.rs index aac8204e29..2026ec5fec 100644 --- a/native/spark-expr/src/string_funcs/mod.rs +++ b/native/spark-expr/src/string_funcs/mod.rs @@ -15,8 +15,10 @@ // specific language governing permissions and limitations // under the License. +mod regexp_extract; mod string_space; mod substring; +pub use regexp_extract::{SparkRegExpExtract, SparkRegExpExtractAll}; pub use string_space::SparkStringSpace; pub use substring::SubstringExpr; diff --git a/native/spark-expr/src/string_funcs/regexp_extract.rs b/native/spark-expr/src/string_funcs/regexp_extract.rs new file mode 100644 index 0000000000..2a9ce6b82c --- /dev/null +++ b/native/spark-expr/src/string_funcs/regexp_extract.rs @@ -0,0 +1,401 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, GenericStringArray}; +use arrow::datatypes::DataType; +use datafusion::common::{exec_err, internal_datafusion_err, Result, ScalarValue}; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use regex::Regex; +use std::sync::Arc; +use std::any::Any; + +/// Spark-compatible regexp_extract function +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkRegExpExtract { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkRegExpExtract { + fn default() -> Self { + Self::new() + } +} + +impl SparkRegExpExtract { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![], + } + } +} + +impl ScalarUDFImpl for SparkRegExpExtract { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "regexp_extract" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + // regexp_extract always returns String + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + // regexp_extract(subject, pattern, idx) + if args.args.len() != 3 { + return exec_err!( + "regexp_extract expects 3 arguments, got {}", + args.args.len() + ); + } + + let subject = &args.args[0]; + let pattern = &args.args[1]; + let idx = &args.args[2]; + + // Pattern must be a literal string + let pattern_str = match pattern { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s.clone(), + _ => { + return exec_err!("regexp_extract pattern must be a string literal"); + } + }; + + // idx must be a literal int + let idx_val = match idx { + ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) => *i as usize, + _ => { + return exec_err!("regexp_extract idx must be an integer literal"); + } + }; + + // Compile regex once + let regex = Regex::new(&pattern_str).map_err(|e| { + internal_datafusion_err!("Invalid regex pattern '{}': {}", pattern_str, e) + })?; + + match subject { + ColumnarValue::Array(array) => { + let result = regexp_extract_array(array, ®ex, idx_val)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(ScalarValue::Utf8(s)) => { + let result = match s { + Some(text) => Some(extract_group(text, ®ex, idx_val)), + None => None, // NULL input → NULL output + }; + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + } + _ => exec_err!("regexp_extract expects string input"), + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Spark-compatible regexp_extract_all function +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkRegExpExtractAll { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkRegExpExtractAll { + fn default() -> Self { + Self::new() + } +} + +impl SparkRegExpExtractAll { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![], + } + } +} + +impl ScalarUDFImpl for SparkRegExpExtractAll { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "regexp_extract_all" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + // regexp_extract_all returns Array + Ok(DataType::List(Arc::new( + arrow::datatypes::Field::new("item", DataType::Utf8, true), + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + // regexp_extract_all(subject, pattern) or regexp_extract_all(subject, pattern, idx) + if args.args.len() < 2 || args.args.len() > 3 { + return exec_err!( + "regexp_extract_all expects 2 or 3 arguments, got {}", + args.args.len() + ); + } + + let subject = &args.args[0]; + let pattern = &args.args[1]; + let idx_val = if args.args.len() == 3 { + match &args.args[2] { + ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) => *i as usize, + _ => { + return exec_err!("regexp_extract_all idx must be an integer literal"); + } + } + } else { + 0 // default to group 0 (entire match) + }; + + // Pattern must be a literal string + let pattern_str = match pattern { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s.clone(), + _ => { + return exec_err!("regexp_extract_all pattern must be a string literal"); + } + }; + + // Compile regex once + let regex = Regex::new(&pattern_str).map_err(|e| { + internal_datafusion_err!("Invalid regex pattern '{}': {}", pattern_str, e) + })?; + + match subject { + ColumnarValue::Array(array) => { + let result = regexp_extract_all_array(array, ®ex, idx_val)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(ScalarValue::Utf8(s)) => { + match s { + Some(text) => { + let matches = extract_all_groups(text, ®ex, idx_val); + // Build a list array with a single element + let mut list_builder = + arrow::array::ListBuilder::new(arrow::array::StringBuilder::new()); + for m in matches { + list_builder.values().append_value(m); + } + list_builder.append(true); + let list_array = list_builder.finish(); + + Ok(ColumnarValue::Scalar(ScalarValue::List( + Arc::new(list_array), + ))) + } + None => { + // Return NULL list using try_into (same as planner.rs:424) + let null_list: ScalarValue = DataType::List(Arc::new( + arrow::datatypes::Field::new("item", DataType::Utf8, true) + )).try_into()?; + Ok(ColumnarValue::Scalar(null_list)) + } + } + } + _ => exec_err!("regexp_extract_all expects string input"), + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +// Helper functions + +fn extract_group(text: &str, regex: &Regex, idx: usize) -> String { + regex + .captures(text) + .and_then(|caps| caps.get(idx)) + .map(|m| m.as_str().to_string()) + // Spark behavior: return empty string "" if no match or group not found + .unwrap_or_else(|| String::new()) +} + +fn regexp_extract_array(array: &ArrayRef, regex: &Regex, idx: usize) -> Result { + let string_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + internal_datafusion_err!("regexp_extract expects string array input") + })?; + + let result: GenericStringArray = string_array + .iter() + .map(|s| s.map(|text| extract_group(text, regex, idx))) // NULL → None, non-NULL → Some("") + .collect(); + + Ok(Arc::new(result)) +} + +fn extract_all_groups(text: &str, regex: &Regex, idx: usize) -> Vec { + regex + .captures_iter(text) + .filter_map(|caps| caps.get(idx).map(|m| m.as_str().to_string())) + .collect() +} + +fn regexp_extract_all_array(array: &ArrayRef, regex: &Regex, idx: usize) -> Result { + let string_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + internal_datafusion_err!("regexp_extract_all expects string array input") + })?; + + let mut list_builder = + arrow::array::ListBuilder::new(arrow::array::StringBuilder::new()); + + for s in string_array.iter() { + match s { + Some(text) => { + let matches = extract_all_groups(text, regex, idx); + for m in matches { + list_builder.values().append_value(m); + } + list_builder.append(true); + } + None => { + list_builder.append(false); + } + } + } + + Ok(Arc::new(list_builder.finish())) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::StringArray; + + #[test] + fn test_regexp_extract_basic() { + let regex = Regex::new(r"(\d+)-(\w+)").unwrap(); + + // Spark behavior: return "" on no match, not None + assert_eq!(extract_group("123-abc", ®ex, 0), "123-abc"); + assert_eq!(extract_group("123-abc", ®ex, 1), "123"); + assert_eq!(extract_group("123-abc", ®ex, 2), "abc"); + assert_eq!(extract_group("123-abc", ®ex, 3), ""); // no such group → "" + assert_eq!(extract_group("no match", ®ex, 0), ""); // no match → "" + } + + #[test] + fn test_regexp_extract_all_basic() { + let regex = Regex::new(r"(\d+)").unwrap(); + + // Multiple matches + let matches = extract_all_groups("a1b2c3", ®ex, 0); + assert_eq!(matches, vec!["1", "2", "3"]); + + // Same with group index 1 + let matches = extract_all_groups("a1b2c3", ®ex, 1); + assert_eq!(matches, vec!["1", "2", "3"]); + + // No match + let matches = extract_all_groups("no digits", ®ex, 0); + assert!(matches.is_empty()); + assert_eq!(matches, Vec::::new()); + } + + #[test] + fn test_regexp_extract_all_array() -> Result<()> { + use datafusion::common::cast::as_list_array; + + let regex = Regex::new(r"(\d+)").unwrap(); + let array = Arc::new(StringArray::from(vec![ + Some("a1b2"), + Some("no digits"), + None, + Some("c3d4e5"), + ])) as ArrayRef; + + let result = regexp_extract_all_array(&array, ®ex, 0)?; + let list_array = as_list_array(&result)?; + + // Row 0: "a1b2" → ["1", "2"] + let row0 = list_array.value(0); + let row0_str = row0.as_any().downcast_ref::>().unwrap(); + assert_eq!(row0_str.len(), 2); + assert_eq!(row0_str.value(0), "1"); + assert_eq!(row0_str.value(1), "2"); + + // Row 1: "no digits" → [] (empty array, not NULL) + let row1 = list_array.value(1); + let row1_str = row1.as_any().downcast_ref::>().unwrap(); + assert_eq!(row1_str.len(), 0); // Empty array + assert!(!list_array.is_null(1)); // Not NULL, just empty + + // Row 2: NULL input → NULL output + assert!(list_array.is_null(2)); + + // Row 3: "c3d4e5" → ["3", "4", "5"] + let row3 = list_array.value(3); + let row3_str = row3.as_any().downcast_ref::>().unwrap(); + assert_eq!(row3_str.len(), 3); + assert_eq!(row3_str.value(0), "3"); + assert_eq!(row3_str.value(1), "4"); + assert_eq!(row3_str.value(2), "5"); + + Ok(()) + } + + #[test] + fn test_regexp_extract_array() -> Result<()> { + let regex = Regex::new(r"(\d+)-(\w+)").unwrap(); + let array = Arc::new(StringArray::from(vec![ + Some("123-abc"), + Some("456-def"), + None, + Some("no-match"), + ])) as ArrayRef; + + let result = regexp_extract_array(&array, ®ex, 1)?; + let result_array = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result_array.value(0), "123"); + assert_eq!(result_array.value(1), "456"); + assert!(result_array.is_null(2)); // NULL input → NULL output + assert_eq!(result_array.value(3), ""); // no match → "" (empty string) + + Ok(()) + } +} + diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 54df2f1688..d18e84ffab 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -153,6 +153,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Like] -> CometLike, classOf[Lower] -> CometLower, classOf[OctetLength] -> CometScalarFunction("octet_length"), + classOf[RegExpExtract] -> CometRegExpExtract, + classOf[RegExpExtractAll] -> CometRegExpExtractAll, classOf[RegExpReplace] -> CometRegExpReplace, classOf[Reverse] -> CometReverse, classOf[RLike] -> CometRLike, diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 15f4b238f2..0756615bd2 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, InitCap, Length, Like, Literal, Lower, RegExpReplace, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, InitCap, Length, Like, Literal, Lower, RegExpReplace, RegExpExtract, RegExpExtractAll, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper} import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType} import org.apache.comet.CometConf @@ -286,3 +286,101 @@ trait CommonStringExprs { } } } + +object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { + override def getSupportLevel(expr: RegExpExtract): SupportLevel = { + // Check if the pattern is compatible with Spark + expr.regexp match { + case Literal(pattern, DataTypes.StringType) => + if (!RegExp.isSupportedPattern(pattern.toString) && + !CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) { + withInfo( + expr, + s"Regexp pattern $pattern is not compatible with Spark. " + + s"Set ${CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key}=true " + + "to allow it anyway.") + return Incompatible() + } + case _ => + // Pattern must be a literal + return Unsupported(Some("Only literal regexp patterns are supported")) + } + + // Check if idx is a literal + expr.idx match { + case Literal(_, DataTypes.IntegerType) => Compatible() + case _ => + Unsupported(Some("Only literal group index is supported")) + } + } + + override def convert( + expr: RegExpExtract, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) + val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) + val idxExpr = exprToProtoInternal(expr.idx, inputs, binding) + + val optExpr = scalarFunctionExprToProto( + "regexp_extract", + subjectExpr, + patternExpr, + idxExpr) + optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.idx) + } +} + +object CometRegExpExtractAll extends CometExpressionSerde[RegExpExtractAll] { + override def getSupportLevel(expr: RegExpExtractAll): SupportLevel = { + // Check if the pattern is compatible with Spark + expr.regexp match { + case Literal(pattern, DataTypes.StringType) => + if (!RegExp.isSupportedPattern(pattern.toString) && + !CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) { + withInfo( + expr, + s"Regexp pattern $pattern is not compatible with Spark. " + + s"Set ${CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key}=true " + + "to allow it anyway.") + return Incompatible() + } + case _ => + // Pattern must be a literal + return Unsupported(Some("Only literal regexp patterns are supported")) + } + + // Check if idx is a literal if exists + if (expr.idx.isDefined) { + expr.idx.get match { + case Literal(_, DataTypes.IntegerType) => Compatible() + case _ => return Unsupported(Some("Only literal group index is supported")) + } + } + } + + override def convert(expr: RegExpExtractAll, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) + val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) + + val optExpr = if (expr.idx.isDefined) { + val idxExpr = exprToProtoInternal(expr.idx.get, inputs, binding) + scalarFunctionExprToProto( + "regexp_extract_all", + subjectExpr, + patternExpr, + idxExpr) + } else { + scalarFunctionExprToProto( + "regexp_extract_all", + subjectExpr, + patternExpr) + } + + if (expr.idx.isDefined) { + optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.idx.get) + } else { + optExprWithInfo(optExpr, expr, expr.subject, expr.regexp) + } + } +} \ No newline at end of file diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index f9882780c8..ffa609b8f1 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -391,4 +391,126 @@ class CometStringExpressionSuite extends CometTestBase { } } + test("regexp_extract basic") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("100-200", 1), + ("300-400", 1), + (null, 1), // NULL input + ("no-match", 1), // no match → should return "" + ("abc123def456", 1), + ("", 1) // empty string + ) + + withParquetTable(data, "tbl") { + // Test basic extraction: group 0 (full match) + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 0) FROM tbl") + // Test group 1 + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 1) FROM tbl") + // Test group 2 + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 2) FROM tbl") + // Test non-existent group → should return "" + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 3) FROM tbl") + } + } + } + + test("regexp_extract edge cases") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("email@example.com", 1), + ("phone: 123-456-7890", 1), + ("price: $99.99", 1), + (null, 1) + ) + + withParquetTable(data, "tbl") { + // Extract email domain + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '@([^.]+)', 1) FROM tbl") + // Extract phone number + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d{3}-\\d{3}-\\d{4})', 1) FROM tbl") + // Extract price + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '\\$(\\d+\\.\\d+)', 1) FROM tbl") + } + } + } + + test("regexp_extract_all basic") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("a1b2c3", 1), + ("test123test456", 1), + (null, 1), // NULL input + ("no digits", 1), // no match → should return [] + ("", 1) // empty string + ) + + withParquetTable(data, "tbl") { + // Test default (group 0) + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '\\d+') FROM tbl") + // Test with explicit group 0 + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '(\\d+)', 0) FROM tbl") + // Test group 1 + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '(\\d+)', 1) FROM tbl") + } + } + } + + test("regexp_extract_all multiple matches") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("The prices are $10, $20, and $30", 1), + ("colors: red, green, blue", 1), + ("words: hello world", 1), + (null, 1) + ) + + withParquetTable(data, "tbl") { + // Extract all prices + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '\\$(\\d+)', 1) FROM tbl") + // Extract all words + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '([a-z]+)', 1) FROM tbl") + } + } + } + + test("regexp_extract_all with dictionary encoding") { + import org.apache.comet.CometConf + + withSQLConf( + CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true", + "parquet.enable.dictionary" -> "true") { + // Use repeated values to trigger dictionary encoding + val data = (0 until 1000).map(i => { + val text = if (i % 3 == 0) "a1b2c3" else if (i % 3 == 1) "x5y6" else "no-match" + (text, 1) + }) + + withParquetTable(data, "tbl") { + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '\\d+') FROM tbl") + } + } + } + } From 4dbed777f3285af2d7a6c9e3cbc6e6ac1d84d5ed Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 29 Nov 2025 22:09:39 -0800 Subject: [PATCH 2/4] refactor strings.scala --- .../org/apache/comet/serde/strings.scala | 65 ++++++++----------- 1 file changed, 27 insertions(+), 38 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 0756615bd2..a4124048ae 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, InitCap, Length, Like, Literal, Lower, RegExpReplace, RegExpExtract, RegExpExtractAll, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, InitCap, Length, Like, Literal, Lower, RegExpExtract, RegExpExtractAll, RegExpReplace, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper} import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType} import org.apache.comet.CometConf @@ -289,7 +289,7 @@ trait CommonStringExprs { object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { override def getSupportLevel(expr: RegExpExtract): SupportLevel = { - // Check if the pattern is compatible with Spark + // Check if the pattern is compatible with Spark or allow incompatible patterns expr.regexp match { case Literal(pattern, DataTypes.StringType) => if (!RegExp.isSupportedPattern(pattern.toString) && @@ -302,13 +302,13 @@ object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { return Incompatible() } case _ => - // Pattern must be a literal return Unsupported(Some("Only literal regexp patterns are supported")) } - + // Check if idx is a literal expr.idx match { - case Literal(_, DataTypes.IntegerType) => Compatible() + case Literal(_, DataTypes.IntegerType) => + Compatible() case _ => Unsupported(Some("Only literal group index is supported")) } @@ -321,7 +321,6 @@ object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) val idxExpr = exprToProtoInternal(expr.idx, inputs, binding) - val optExpr = scalarFunctionExprToProto( "regexp_extract", subjectExpr, @@ -333,7 +332,7 @@ object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { object CometRegExpExtractAll extends CometExpressionSerde[RegExpExtractAll] { override def getSupportLevel(expr: RegExpExtractAll): SupportLevel = { - // Check if the pattern is compatible with Spark + // Check if the pattern is compatible with Spark or allow incompatible patterns expr.regexp match { case Literal(pattern, DataTypes.StringType) => if (!RegExp.isSupportedPattern(pattern.toString) && @@ -346,41 +345,31 @@ object CometRegExpExtractAll extends CometExpressionSerde[RegExpExtractAll] { return Incompatible() } case _ => - // Pattern must be a literal return Unsupported(Some("Only literal regexp patterns are supported")) } - - // Check if idx is a literal if exists - if (expr.idx.isDefined) { - expr.idx.get match { - case Literal(_, DataTypes.IntegerType) => Compatible() - case _ => return Unsupported(Some("Only literal group index is supported")) - } + + // Check if idx is a literal + // For regexp_extract_all, idx will be default to 1 if not specified + expr.idx match { + case Literal(_, DataTypes.IntegerType) => + Compatible() + case _ => + Unsupported(Some("Only literal group index is supported")) } } - - override def convert(expr: RegExpExtractAll, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + override def convert( + expr: RegExpExtractAll, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + // Check if the pattern is compatible with Spark or allow incompatible patterns val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) - - val optExpr = if (expr.idx.isDefined) { - val idxExpr = exprToProtoInternal(expr.idx.get, inputs, binding) - scalarFunctionExprToProto( - "regexp_extract_all", - subjectExpr, - patternExpr, - idxExpr) - } else { - scalarFunctionExprToProto( - "regexp_extract_all", - subjectExpr, - patternExpr) - } - - if (expr.idx.isDefined) { - optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.idx.get) - } else { - optExprWithInfo(optExpr, expr, expr.subject, expr.regexp) - } + val idxExpr = exprToProtoInternal(expr.idx, inputs, binding) + val optExpr = scalarFunctionExprToProto( + "regexp_extract_all", + subjectExpr, + patternExpr, + idxExpr) + optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.idx) } -} \ No newline at end of file +} From f1013628ea975c3bf8ec7fc1f2eefb412482fbaf Mon Sep 17 00:00:00 2001 From: Daniel Tu Date: Sun, 30 Nov 2025 21:35:49 -0800 Subject: [PATCH 3/4] test, format and configs --- docs/source/user-guide/latest/configs.md | 2 + .../org/apache/comet/serde/strings.scala | 15 +--- .../comet/CometStringExpressionSuite.scala | 90 +++++++++---------- 3 files changed, 47 insertions(+), 60 deletions(-) diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md index a1c3212c20..f5638d5cf4 100644 --- a/docs/source/user-guide/latest/configs.md +++ b/docs/source/user-guide/latest/configs.md @@ -291,6 +291,8 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.RLike.enabled` | Enable Comet acceleration for `RLike` | true | | `spark.comet.expression.Rand.enabled` | Enable Comet acceleration for `Rand` | true | | `spark.comet.expression.Randn.enabled` | Enable Comet acceleration for `Randn` | true | +| `spark.comet.expression.RegExpExtract.enabled` | Enable Comet acceleration for `RegExpExtract` | true | +| `spark.comet.expression.RegExpExtractAll.enabled` | Enable Comet acceleration for `RegExpExtractAll` | true | | `spark.comet.expression.RegExpReplace.enabled` | Enable Comet acceleration for `RegExpReplace` | true | | `spark.comet.expression.Remainder.enabled` | Enable Comet acceleration for `Remainder` | true | | `spark.comet.expression.Reverse.enabled` | Enable Comet acceleration for `Reverse` | true | diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index a4124048ae..733c25ec2b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -321,11 +321,7 @@ object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) val idxExpr = exprToProtoInternal(expr.idx, inputs, binding) - val optExpr = scalarFunctionExprToProto( - "regexp_extract", - subjectExpr, - patternExpr, - idxExpr) + val optExpr = scalarFunctionExprToProto("regexp_extract", subjectExpr, patternExpr, idxExpr) optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.idx) } } @@ -349,7 +345,7 @@ object CometRegExpExtractAll extends CometExpressionSerde[RegExpExtractAll] { } // Check if idx is a literal - // For regexp_extract_all, idx will be default to 1 if not specified + // For regexp_extract_all, idx will default to 0 (group 0, entire match) if not specified expr.idx match { case Literal(_, DataTypes.IntegerType) => Compatible() @@ -365,11 +361,8 @@ object CometRegExpExtractAll extends CometExpressionSerde[RegExpExtractAll] { val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) val idxExpr = exprToProtoInternal(expr.idx, inputs, binding) - val optExpr = scalarFunctionExprToProto( - "regexp_extract_all", - subjectExpr, - patternExpr, - idxExpr) + val optExpr = + scalarFunctionExprToProto("regexp_extract_all", subjectExpr, patternExpr, idxExpr) optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.idx) } } diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index ffa609b8f1..5214eb8215 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -393,110 +393,102 @@ class CometStringExpressionSuite extends CometTestBase { test("regexp_extract basic") { import org.apache.comet.CometConf - + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { val data = Seq( ("100-200", 1), ("300-400", 1), - (null, 1), // NULL input - ("no-match", 1), // no match → should return "" + (null, 1), // NULL input + ("no-match", 1), // no match → should return "" ("abc123def456", 1), - ("", 1) // empty string + ("", 1) // empty string ) - + withParquetTable(data, "tbl") { // Test basic extraction: group 0 (full match) - checkSparkAnswerAndOperator( - "SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 0) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 0) FROM tbl") // Test group 1 - checkSparkAnswerAndOperator( - "SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 1) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 1) FROM tbl") // Test group 2 - checkSparkAnswerAndOperator( - "SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 2) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 2) FROM tbl") // Test non-existent group → should return "" - checkSparkAnswerAndOperator( - "SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 3) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 3) FROM tbl") + // Test empty pattern + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '', 0) FROM tbl") + // Test null pattern + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, NULL, 0) FROM tbl") } } } test("regexp_extract edge cases") { import org.apache.comet.CometConf - + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { - val data = Seq( - ("email@example.com", 1), - ("phone: 123-456-7890", 1), - ("price: $99.99", 1), - (null, 1) - ) - + val data = + Seq(("email@example.com", 1), ("phone: 123-456-7890", 1), ("price: $99.99", 1), (null, 1)) + withParquetTable(data, "tbl") { // Extract email domain - checkSparkAnswerAndOperator( - "SELECT regexp_extract(_1, '@([^.]+)', 1) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '@([^.]+)', 1) FROM tbl") // Extract phone number checkSparkAnswerAndOperator( "SELECT regexp_extract(_1, '(\\d{3}-\\d{3}-\\d{4})', 1) FROM tbl") // Extract price - checkSparkAnswerAndOperator( - "SELECT regexp_extract(_1, '\\$(\\d+\\.\\d+)', 1) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '\\$(\\d+\\.\\d+)', 1) FROM tbl") } } } test("regexp_extract_all basic") { import org.apache.comet.CometConf - + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { val data = Seq( ("a1b2c3", 1), ("test123test456", 1), - (null, 1), // NULL input - ("no digits", 1), // no match → should return [] - ("", 1) // empty string + (null, 1), // NULL input + ("no digits", 1), // no match → should return [] + ("", 1) // empty string ) - + withParquetTable(data, "tbl") { - // Test default (group 0) - checkSparkAnswerAndOperator( - "SELECT regexp_extract_all(_1, '\\d+') FROM tbl") + // Test with explicit group 0 (full match on no-group pattern) + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '\\d+', 0) FROM tbl") // Test with explicit group 0 - checkSparkAnswerAndOperator( - "SELECT regexp_extract_all(_1, '(\\d+)', 0) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '(\\d+)', 0) FROM tbl") // Test group 1 - checkSparkAnswerAndOperator( - "SELECT regexp_extract_all(_1, '(\\d+)', 1) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '(\\d+)', 1) FROM tbl") + // Test empty pattern + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '', 0) FROM tbl") + // Test null pattern + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, NULL, 0) FROM tbl") } } } test("regexp_extract_all multiple matches") { import org.apache.comet.CometConf - + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { val data = Seq( ("The prices are $10, $20, and $30", 1), ("colors: red, green, blue", 1), ("words: hello world", 1), - (null, 1) - ) - + (null, 1)) + withParquetTable(data, "tbl") { // Extract all prices - checkSparkAnswerAndOperator( - "SELECT regexp_extract_all(_1, '\\$(\\d+)', 1) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '\\$(\\d+)', 1) FROM tbl") // Extract all words - checkSparkAnswerAndOperator( - "SELECT regexp_extract_all(_1, '([a-z]+)', 1) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '([a-z]+)', 1) FROM tbl") } } } test("regexp_extract_all with dictionary encoding") { import org.apache.comet.CometConf - + withSQLConf( CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true", "parquet.enable.dictionary" -> "true") { @@ -505,10 +497,10 @@ class CometStringExpressionSuite extends CometTestBase { val text = if (i % 3 == 0) "a1b2c3" else if (i % 3 == 1) "x5y6" else "no-match" (text, 1) }) - + withParquetTable(data, "tbl") { - checkSparkAnswerAndOperator( - "SELECT regexp_extract_all(_1, '\\d+') FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '\\d+') FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '\\d+', 0) FROM tbl") } } } From ff1ebd6b3bebe85c0f393b268660dc8031614bc1 Mon Sep 17 00:00:00 2001 From: Daniel Tu Date: Mon, 1 Dec 2025 00:05:28 -0800 Subject: [PATCH 4/4] make regexp_extract more align with spark's behavior --- .../src/string_funcs/regexp_extract.rs | 169 ++++++++++++------ 1 file changed, 110 insertions(+), 59 deletions(-) diff --git a/native/spark-expr/src/string_funcs/regexp_extract.rs b/native/spark-expr/src/string_funcs/regexp_extract.rs index 2a9ce6b82c..eba2e7993c 100644 --- a/native/spark-expr/src/string_funcs/regexp_extract.rs +++ b/native/spark-expr/src/string_funcs/regexp_extract.rs @@ -22,8 +22,8 @@ use datafusion::logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use regex::Regex; -use std::sync::Arc; use std::any::Any; +use std::sync::Arc; /// Spark-compatible regexp_extract function #[derive(Debug, PartialEq, Eq, Hash)] @@ -106,8 +106,8 @@ impl ScalarUDFImpl for SparkRegExpExtract { } ColumnarValue::Scalar(ScalarValue::Utf8(s)) => { let result = match s { - Some(text) => Some(extract_group(text, ®ex, idx_val)), - None => None, // NULL input → NULL output + Some(text) => Some(extract_group(text, ®ex, idx_val)?), + None => None, // NULL input → NULL output }; Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) } @@ -157,9 +157,11 @@ impl ScalarUDFImpl for SparkRegExpExtractAll { fn return_type(&self, _arg_types: &[DataType]) -> Result { // regexp_extract_all returns Array - Ok(DataType::List(Arc::new( - arrow::datatypes::Field::new("item", DataType::Utf8, true), - ))) + Ok(DataType::List(Arc::new(arrow::datatypes::Field::new( + "item", + DataType::Utf8, + false, + )))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -181,7 +183,8 @@ impl ScalarUDFImpl for SparkRegExpExtractAll { } } } else { - 0 // default to group 0 (entire match) + // Using 1 here to align with Spark's default behavior. + 1 }; // Pattern must be a literal string @@ -205,7 +208,7 @@ impl ScalarUDFImpl for SparkRegExpExtractAll { ColumnarValue::Scalar(ScalarValue::Utf8(s)) => { match s { Some(text) => { - let matches = extract_all_groups(text, ®ex, idx_val); + let matches = extract_all_groups(text, ®ex, idx_val)?; // Build a list array with a single element let mut list_builder = arrow::array::ListBuilder::new(arrow::array::StringBuilder::new()); @@ -214,16 +217,17 @@ impl ScalarUDFImpl for SparkRegExpExtractAll { } list_builder.append(true); let list_array = list_builder.finish(); - - Ok(ColumnarValue::Scalar(ScalarValue::List( - Arc::new(list_array), - ))) + + Ok(ColumnarValue::Scalar(ScalarValue::List(Arc::new( + list_array, + )))) } None => { // Return NULL list using try_into (same as planner.rs:424) let null_list: ScalarValue = DataType::List(Arc::new( - arrow::datatypes::Field::new("item", DataType::Utf8, true) - )).try_into()?; + arrow::datatypes::Field::new("item", DataType::Utf8, false), + )) + .try_into()?; Ok(ColumnarValue::Scalar(null_list)) } } @@ -239,53 +243,86 @@ impl ScalarUDFImpl for SparkRegExpExtractAll { // Helper functions -fn extract_group(text: &str, regex: &Regex, idx: usize) -> String { - regex - .captures(text) - .and_then(|caps| caps.get(idx)) - .map(|m| m.as_str().to_string()) - // Spark behavior: return empty string "" if no match or group not found - .unwrap_or_else(|| String::new()) +fn extract_group(text: &str, regex: &Regex, idx: usize) -> Result { + match regex.captures(text) { + Some(caps) => { + // Spark behavior: throw error if group index is out of bounds + if idx >= caps.len() { + return exec_err!( + "Regex group count is {}, but the specified group index is {}", + caps.len(), + idx + ); + } + Ok(caps + .get(idx) + .map(|m| m.as_str().to_string()) + .unwrap_or_default()) + } + None => { + // No match: return empty string (Spark behavior) + Ok(String::new()) + } + } } fn regexp_extract_array(array: &ArrayRef, regex: &Regex, idx: usize) -> Result { let string_array = array .as_any() .downcast_ref::>() - .ok_or_else(|| { - internal_datafusion_err!("regexp_extract expects string array input") - })?; + .ok_or_else(|| internal_datafusion_err!("regexp_extract expects string array input"))?; - let result: GenericStringArray = string_array - .iter() - .map(|s| s.map(|text| extract_group(text, regex, idx))) // NULL → None, non-NULL → Some("") - .collect(); + let mut builder = arrow::array::StringBuilder::new(); + for s in string_array.iter() { + match s { + Some(text) => { + let extracted = extract_group(text, regex, idx)?; + builder.append_value(extracted); + } + None => { + builder.append_null(); // NULL → None + } + } + } - Ok(Arc::new(result)) + Ok(Arc::new(builder.finish())) } -fn extract_all_groups(text: &str, regex: &Regex, idx: usize) -> Vec { - regex - .captures_iter(text) - .filter_map(|caps| caps.get(idx).map(|m| m.as_str().to_string())) - .collect() +fn extract_all_groups(text: &str, regex: &Regex, idx: usize) -> Result> { + let mut results = Vec::new(); + + for caps in regex.captures_iter(text) { + // Check bounds for each capture (matches Spark behavior) + if idx >= caps.len() { + return exec_err!( + "Regex group count is {}, but the specified group index is {}", + caps.len(), + idx + ); + } + + let matched = caps + .get(idx) + .map(|m| m.as_str().to_string()) + .unwrap_or_default(); + results.push(matched); + } + + Ok(results) } fn regexp_extract_all_array(array: &ArrayRef, regex: &Regex, idx: usize) -> Result { let string_array = array .as_any() .downcast_ref::>() - .ok_or_else(|| { - internal_datafusion_err!("regexp_extract_all expects string array input") - })?; + .ok_or_else(|| internal_datafusion_err!("regexp_extract_all expects string array input"))?; - let mut list_builder = - arrow::array::ListBuilder::new(arrow::array::StringBuilder::new()); + let mut list_builder = arrow::array::ListBuilder::new(arrow::array::StringBuilder::new()); for s in string_array.iter() { match s { Some(text) => { - let matches = extract_all_groups(text, regex, idx); + let matches = extract_all_groups(text, regex, idx)?; for m in matches { list_builder.values().append_value(m); } @@ -310,11 +347,14 @@ mod tests { let regex = Regex::new(r"(\d+)-(\w+)").unwrap(); // Spark behavior: return "" on no match, not None - assert_eq!(extract_group("123-abc", ®ex, 0), "123-abc"); - assert_eq!(extract_group("123-abc", ®ex, 1), "123"); - assert_eq!(extract_group("123-abc", ®ex, 2), "abc"); - assert_eq!(extract_group("123-abc", ®ex, 3), ""); // no such group → "" - assert_eq!(extract_group("no match", ®ex, 0), ""); // no match → "" + assert_eq!(extract_group("123-abc", ®ex, 0).unwrap(), "123-abc"); + assert_eq!(extract_group("123-abc", ®ex, 1).unwrap(), "123"); + assert_eq!(extract_group("123-abc", ®ex, 2).unwrap(), "abc"); + assert_eq!(extract_group("no match", ®ex, 0).unwrap(), ""); // no match → "" + + // Spark behavior: group index out of bounds → error + assert!(extract_group("123-abc", ®ex, 3).is_err()); + assert!(extract_group("123-abc", ®ex, 99).is_err()); } #[test] @@ -322,23 +362,26 @@ mod tests { let regex = Regex::new(r"(\d+)").unwrap(); // Multiple matches - let matches = extract_all_groups("a1b2c3", ®ex, 0); + let matches = extract_all_groups("a1b2c3", ®ex, 0).unwrap(); assert_eq!(matches, vec!["1", "2", "3"]); // Same with group index 1 - let matches = extract_all_groups("a1b2c3", ®ex, 1); + let matches = extract_all_groups("a1b2c3", ®ex, 1).unwrap(); assert_eq!(matches, vec!["1", "2", "3"]); - // No match - let matches = extract_all_groups("no digits", ®ex, 0); + // No match: returns empty vec, not error + let matches = extract_all_groups("no digits", ®ex, 0).unwrap(); assert!(matches.is_empty()); assert_eq!(matches, Vec::::new()); + + // Group index out of bounds → error + assert!(extract_all_groups("a1b2c3", ®ex, 2).is_err()); } - + #[test] fn test_regexp_extract_all_array() -> Result<()> { use datafusion::common::cast::as_list_array; - + let regex = Regex::new(r"(\d+)").unwrap(); let array = Arc::new(StringArray::from(vec![ Some("a1b2"), @@ -352,23 +395,32 @@ mod tests { // Row 0: "a1b2" → ["1", "2"] let row0 = list_array.value(0); - let row0_str = row0.as_any().downcast_ref::>().unwrap(); + let row0_str = row0 + .as_any() + .downcast_ref::>() + .unwrap(); assert_eq!(row0_str.len(), 2); assert_eq!(row0_str.value(0), "1"); assert_eq!(row0_str.value(1), "2"); // Row 1: "no digits" → [] (empty array, not NULL) let row1 = list_array.value(1); - let row1_str = row1.as_any().downcast_ref::>().unwrap(); - assert_eq!(row1_str.len(), 0); // Empty array - assert!(!list_array.is_null(1)); // Not NULL, just empty + let row1_str = row1 + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(row1_str.len(), 0); // Empty array + assert!(!list_array.is_null(1)); // Not NULL, just empty // Row 2: NULL input → NULL output assert!(list_array.is_null(2)); // Row 3: "c3d4e5" → ["3", "4", "5"] let row3 = list_array.value(3); - let row3_str = row3.as_any().downcast_ref::>().unwrap(); + let row3_str = row3 + .as_any() + .downcast_ref::>() + .unwrap(); assert_eq!(row3_str.len(), 3); assert_eq!(row3_str.value(0), "3"); assert_eq!(row3_str.value(1), "4"); @@ -392,10 +444,9 @@ mod tests { assert_eq!(result_array.value(0), "123"); assert_eq!(result_array.value(1), "456"); - assert!(result_array.is_null(2)); // NULL input → NULL output - assert_eq!(result_array.value(3), ""); // no match → "" (empty string) + assert!(result_array.is_null(2)); // NULL input → NULL output + assert_eq!(result_array.value(3), ""); // no match → "" (empty string) Ok(()) } } -