diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md index a1c3212c20..f5638d5cf4 100644 --- a/docs/source/user-guide/latest/configs.md +++ b/docs/source/user-guide/latest/configs.md @@ -291,6 +291,8 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.RLike.enabled` | Enable Comet acceleration for `RLike` | true | | `spark.comet.expression.Rand.enabled` | Enable Comet acceleration for `Rand` | true | | `spark.comet.expression.Randn.enabled` | Enable Comet acceleration for `Randn` | true | +| `spark.comet.expression.RegExpExtract.enabled` | Enable Comet acceleration for `RegExpExtract` | true | +| `spark.comet.expression.RegExpExtractAll.enabled` | Enable Comet acceleration for `RegExpExtractAll` | true | | `spark.comet.expression.RegExpReplace.enabled` | Enable Comet acceleration for `RegExpReplace` | true | | `spark.comet.expression.Remainder.enabled` | Enable Comet acceleration for `Remainder` | true | | `spark.comet.expression.Reverse.enabled` | Enable Comet acceleration for `Reverse` | true | diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 021bb1c78f..45b4ca8aad 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -23,7 +23,7 @@ use crate::{ spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkBitwiseNot, - SparkDateTrunc, SparkStringSpace, + SparkDateTrunc, SparkRegExpExtract, SparkRegExpExtractAll, SparkStringSpace, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -199,6 +199,8 @@ fn all_scalar_functions() -> Vec> { Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())), Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())), Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())), + Arc::new(ScalarUDF::new_from_impl(SparkRegExpExtract::default())), + Arc::new(ScalarUDF::new_from_impl(SparkRegExpExtractAll::default())), ] } diff --git a/native/spark-expr/src/string_funcs/mod.rs b/native/spark-expr/src/string_funcs/mod.rs index aac8204e29..2026ec5fec 100644 --- a/native/spark-expr/src/string_funcs/mod.rs +++ b/native/spark-expr/src/string_funcs/mod.rs @@ -15,8 +15,10 @@ // specific language governing permissions and limitations // under the License. +mod regexp_extract; mod string_space; mod substring; +pub use regexp_extract::{SparkRegExpExtract, SparkRegExpExtractAll}; pub use string_space::SparkStringSpace; pub use substring::SubstringExpr; diff --git a/native/spark-expr/src/string_funcs/regexp_extract.rs b/native/spark-expr/src/string_funcs/regexp_extract.rs new file mode 100644 index 0000000000..eba2e7993c --- /dev/null +++ b/native/spark-expr/src/string_funcs/regexp_extract.rs @@ -0,0 +1,452 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, GenericStringArray}; +use arrow::datatypes::DataType; +use datafusion::common::{exec_err, internal_datafusion_err, Result, ScalarValue}; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use regex::Regex; +use std::any::Any; +use std::sync::Arc; + +/// Spark-compatible regexp_extract function +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkRegExpExtract { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkRegExpExtract { + fn default() -> Self { + Self::new() + } +} + +impl SparkRegExpExtract { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![], + } + } +} + +impl ScalarUDFImpl for SparkRegExpExtract { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "regexp_extract" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + // regexp_extract always returns String + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + // regexp_extract(subject, pattern, idx) + if args.args.len() != 3 { + return exec_err!( + "regexp_extract expects 3 arguments, got {}", + args.args.len() + ); + } + + let subject = &args.args[0]; + let pattern = &args.args[1]; + let idx = &args.args[2]; + + // Pattern must be a literal string + let pattern_str = match pattern { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s.clone(), + _ => { + return exec_err!("regexp_extract pattern must be a string literal"); + } + }; + + // idx must be a literal int + let idx_val = match idx { + ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) => *i as usize, + _ => { + return exec_err!("regexp_extract idx must be an integer literal"); + } + }; + + // Compile regex once + let regex = Regex::new(&pattern_str).map_err(|e| { + internal_datafusion_err!("Invalid regex pattern '{}': {}", pattern_str, e) + })?; + + match subject { + ColumnarValue::Array(array) => { + let result = regexp_extract_array(array, ®ex, idx_val)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(ScalarValue::Utf8(s)) => { + let result = match s { + Some(text) => Some(extract_group(text, ®ex, idx_val)?), + None => None, // NULL input → NULL output + }; + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + } + _ => exec_err!("regexp_extract expects string input"), + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Spark-compatible regexp_extract_all function +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkRegExpExtractAll { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkRegExpExtractAll { + fn default() -> Self { + Self::new() + } +} + +impl SparkRegExpExtractAll { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![], + } + } +} + +impl ScalarUDFImpl for SparkRegExpExtractAll { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "regexp_extract_all" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + // regexp_extract_all returns Array + Ok(DataType::List(Arc::new(arrow::datatypes::Field::new( + "item", + DataType::Utf8, + false, + )))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + // regexp_extract_all(subject, pattern) or regexp_extract_all(subject, pattern, idx) + if args.args.len() < 2 || args.args.len() > 3 { + return exec_err!( + "regexp_extract_all expects 2 or 3 arguments, got {}", + args.args.len() + ); + } + + let subject = &args.args[0]; + let pattern = &args.args[1]; + let idx_val = if args.args.len() == 3 { + match &args.args[2] { + ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) => *i as usize, + _ => { + return exec_err!("regexp_extract_all idx must be an integer literal"); + } + } + } else { + // Using 1 here to align with Spark's default behavior. + 1 + }; + + // Pattern must be a literal string + let pattern_str = match pattern { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s.clone(), + _ => { + return exec_err!("regexp_extract_all pattern must be a string literal"); + } + }; + + // Compile regex once + let regex = Regex::new(&pattern_str).map_err(|e| { + internal_datafusion_err!("Invalid regex pattern '{}': {}", pattern_str, e) + })?; + + match subject { + ColumnarValue::Array(array) => { + let result = regexp_extract_all_array(array, ®ex, idx_val)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(ScalarValue::Utf8(s)) => { + match s { + Some(text) => { + let matches = extract_all_groups(text, ®ex, idx_val)?; + // Build a list array with a single element + let mut list_builder = + arrow::array::ListBuilder::new(arrow::array::StringBuilder::new()); + for m in matches { + list_builder.values().append_value(m); + } + list_builder.append(true); + let list_array = list_builder.finish(); + + Ok(ColumnarValue::Scalar(ScalarValue::List(Arc::new( + list_array, + )))) + } + None => { + // Return NULL list using try_into (same as planner.rs:424) + let null_list: ScalarValue = DataType::List(Arc::new( + arrow::datatypes::Field::new("item", DataType::Utf8, false), + )) + .try_into()?; + Ok(ColumnarValue::Scalar(null_list)) + } + } + } + _ => exec_err!("regexp_extract_all expects string input"), + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +// Helper functions + +fn extract_group(text: &str, regex: &Regex, idx: usize) -> Result { + match regex.captures(text) { + Some(caps) => { + // Spark behavior: throw error if group index is out of bounds + if idx >= caps.len() { + return exec_err!( + "Regex group count is {}, but the specified group index is {}", + caps.len(), + idx + ); + } + Ok(caps + .get(idx) + .map(|m| m.as_str().to_string()) + .unwrap_or_default()) + } + None => { + // No match: return empty string (Spark behavior) + Ok(String::new()) + } + } +} + +fn regexp_extract_array(array: &ArrayRef, regex: &Regex, idx: usize) -> Result { + let string_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| internal_datafusion_err!("regexp_extract expects string array input"))?; + + let mut builder = arrow::array::StringBuilder::new(); + for s in string_array.iter() { + match s { + Some(text) => { + let extracted = extract_group(text, regex, idx)?; + builder.append_value(extracted); + } + None => { + builder.append_null(); // NULL → None + } + } + } + + Ok(Arc::new(builder.finish())) +} + +fn extract_all_groups(text: &str, regex: &Regex, idx: usize) -> Result> { + let mut results = Vec::new(); + + for caps in regex.captures_iter(text) { + // Check bounds for each capture (matches Spark behavior) + if idx >= caps.len() { + return exec_err!( + "Regex group count is {}, but the specified group index is {}", + caps.len(), + idx + ); + } + + let matched = caps + .get(idx) + .map(|m| m.as_str().to_string()) + .unwrap_or_default(); + results.push(matched); + } + + Ok(results) +} + +fn regexp_extract_all_array(array: &ArrayRef, regex: &Regex, idx: usize) -> Result { + let string_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| internal_datafusion_err!("regexp_extract_all expects string array input"))?; + + let mut list_builder = arrow::array::ListBuilder::new(arrow::array::StringBuilder::new()); + + for s in string_array.iter() { + match s { + Some(text) => { + let matches = extract_all_groups(text, regex, idx)?; + for m in matches { + list_builder.values().append_value(m); + } + list_builder.append(true); + } + None => { + list_builder.append(false); + } + } + } + + Ok(Arc::new(list_builder.finish())) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::StringArray; + + #[test] + fn test_regexp_extract_basic() { + let regex = Regex::new(r"(\d+)-(\w+)").unwrap(); + + // Spark behavior: return "" on no match, not None + assert_eq!(extract_group("123-abc", ®ex, 0).unwrap(), "123-abc"); + assert_eq!(extract_group("123-abc", ®ex, 1).unwrap(), "123"); + assert_eq!(extract_group("123-abc", ®ex, 2).unwrap(), "abc"); + assert_eq!(extract_group("no match", ®ex, 0).unwrap(), ""); // no match → "" + + // Spark behavior: group index out of bounds → error + assert!(extract_group("123-abc", ®ex, 3).is_err()); + assert!(extract_group("123-abc", ®ex, 99).is_err()); + } + + #[test] + fn test_regexp_extract_all_basic() { + let regex = Regex::new(r"(\d+)").unwrap(); + + // Multiple matches + let matches = extract_all_groups("a1b2c3", ®ex, 0).unwrap(); + assert_eq!(matches, vec!["1", "2", "3"]); + + // Same with group index 1 + let matches = extract_all_groups("a1b2c3", ®ex, 1).unwrap(); + assert_eq!(matches, vec!["1", "2", "3"]); + + // No match: returns empty vec, not error + let matches = extract_all_groups("no digits", ®ex, 0).unwrap(); + assert!(matches.is_empty()); + assert_eq!(matches, Vec::::new()); + + // Group index out of bounds → error + assert!(extract_all_groups("a1b2c3", ®ex, 2).is_err()); + } + + #[test] + fn test_regexp_extract_all_array() -> Result<()> { + use datafusion::common::cast::as_list_array; + + let regex = Regex::new(r"(\d+)").unwrap(); + let array = Arc::new(StringArray::from(vec![ + Some("a1b2"), + Some("no digits"), + None, + Some("c3d4e5"), + ])) as ArrayRef; + + let result = regexp_extract_all_array(&array, ®ex, 0)?; + let list_array = as_list_array(&result)?; + + // Row 0: "a1b2" → ["1", "2"] + let row0 = list_array.value(0); + let row0_str = row0 + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(row0_str.len(), 2); + assert_eq!(row0_str.value(0), "1"); + assert_eq!(row0_str.value(1), "2"); + + // Row 1: "no digits" → [] (empty array, not NULL) + let row1 = list_array.value(1); + let row1_str = row1 + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(row1_str.len(), 0); // Empty array + assert!(!list_array.is_null(1)); // Not NULL, just empty + + // Row 2: NULL input → NULL output + assert!(list_array.is_null(2)); + + // Row 3: "c3d4e5" → ["3", "4", "5"] + let row3 = list_array.value(3); + let row3_str = row3 + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(row3_str.len(), 3); + assert_eq!(row3_str.value(0), "3"); + assert_eq!(row3_str.value(1), "4"); + assert_eq!(row3_str.value(2), "5"); + + Ok(()) + } + + #[test] + fn test_regexp_extract_array() -> Result<()> { + let regex = Regex::new(r"(\d+)-(\w+)").unwrap(); + let array = Arc::new(StringArray::from(vec![ + Some("123-abc"), + Some("456-def"), + None, + Some("no-match"), + ])) as ArrayRef; + + let result = regexp_extract_array(&array, ®ex, 1)?; + let result_array = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result_array.value(0), "123"); + assert_eq!(result_array.value(1), "456"); + assert!(result_array.is_null(2)); // NULL input → NULL output + assert_eq!(result_array.value(3), ""); // no match → "" (empty string) + + Ok(()) + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 54df2f1688..d18e84ffab 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -153,6 +153,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Like] -> CometLike, classOf[Lower] -> CometLower, classOf[OctetLength] -> CometScalarFunction("octet_length"), + classOf[RegExpExtract] -> CometRegExpExtract, + classOf[RegExpExtractAll] -> CometRegExpExtractAll, classOf[RegExpReplace] -> CometRegExpReplace, classOf[Reverse] -> CometReverse, classOf[RLike] -> CometRLike, diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 15f4b238f2..733c25ec2b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, InitCap, Length, Like, Literal, Lower, RegExpReplace, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, InitCap, Length, Like, Literal, Lower, RegExpExtract, RegExpExtractAll, RegExpReplace, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper} import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType} import org.apache.comet.CometConf @@ -286,3 +286,83 @@ trait CommonStringExprs { } } } + +object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { + override def getSupportLevel(expr: RegExpExtract): SupportLevel = { + // Check if the pattern is compatible with Spark or allow incompatible patterns + expr.regexp match { + case Literal(pattern, DataTypes.StringType) => + if (!RegExp.isSupportedPattern(pattern.toString) && + !CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) { + withInfo( + expr, + s"Regexp pattern $pattern is not compatible with Spark. " + + s"Set ${CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key}=true " + + "to allow it anyway.") + return Incompatible() + } + case _ => + return Unsupported(Some("Only literal regexp patterns are supported")) + } + + // Check if idx is a literal + expr.idx match { + case Literal(_, DataTypes.IntegerType) => + Compatible() + case _ => + Unsupported(Some("Only literal group index is supported")) + } + } + + override def convert( + expr: RegExpExtract, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) + val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) + val idxExpr = exprToProtoInternal(expr.idx, inputs, binding) + val optExpr = scalarFunctionExprToProto("regexp_extract", subjectExpr, patternExpr, idxExpr) + optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.idx) + } +} + +object CometRegExpExtractAll extends CometExpressionSerde[RegExpExtractAll] { + override def getSupportLevel(expr: RegExpExtractAll): SupportLevel = { + // Check if the pattern is compatible with Spark or allow incompatible patterns + expr.regexp match { + case Literal(pattern, DataTypes.StringType) => + if (!RegExp.isSupportedPattern(pattern.toString) && + !CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) { + withInfo( + expr, + s"Regexp pattern $pattern is not compatible with Spark. " + + s"Set ${CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key}=true " + + "to allow it anyway.") + return Incompatible() + } + case _ => + return Unsupported(Some("Only literal regexp patterns are supported")) + } + + // Check if idx is a literal + // For regexp_extract_all, idx will default to 0 (group 0, entire match) if not specified + expr.idx match { + case Literal(_, DataTypes.IntegerType) => + Compatible() + case _ => + Unsupported(Some("Only literal group index is supported")) + } + } + override def convert( + expr: RegExpExtractAll, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + // Check if the pattern is compatible with Spark or allow incompatible patterns + val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) + val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) + val idxExpr = exprToProtoInternal(expr.idx, inputs, binding) + val optExpr = + scalarFunctionExprToProto("regexp_extract_all", subjectExpr, patternExpr, idxExpr) + optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.idx) + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index f9882780c8..5214eb8215 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -391,4 +391,118 @@ class CometStringExpressionSuite extends CometTestBase { } } + test("regexp_extract basic") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("100-200", 1), + ("300-400", 1), + (null, 1), // NULL input + ("no-match", 1), // no match → should return "" + ("abc123def456", 1), + ("", 1) // empty string + ) + + withParquetTable(data, "tbl") { + // Test basic extraction: group 0 (full match) + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 0) FROM tbl") + // Test group 1 + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 1) FROM tbl") + // Test group 2 + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 2) FROM tbl") + // Test non-existent group → should return "" + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 3) FROM tbl") + // Test empty pattern + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '', 0) FROM tbl") + // Test null pattern + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, NULL, 0) FROM tbl") + } + } + } + + test("regexp_extract edge cases") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = + Seq(("email@example.com", 1), ("phone: 123-456-7890", 1), ("price: $99.99", 1), (null, 1)) + + withParquetTable(data, "tbl") { + // Extract email domain + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '@([^.]+)', 1) FROM tbl") + // Extract phone number + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d{3}-\\d{3}-\\d{4})', 1) FROM tbl") + // Extract price + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '\\$(\\d+\\.\\d+)', 1) FROM tbl") + } + } + } + + test("regexp_extract_all basic") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("a1b2c3", 1), + ("test123test456", 1), + (null, 1), // NULL input + ("no digits", 1), // no match → should return [] + ("", 1) // empty string + ) + + withParquetTable(data, "tbl") { + // Test with explicit group 0 (full match on no-group pattern) + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '\\d+', 0) FROM tbl") + // Test with explicit group 0 + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '(\\d+)', 0) FROM tbl") + // Test group 1 + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '(\\d+)', 1) FROM tbl") + // Test empty pattern + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '', 0) FROM tbl") + // Test null pattern + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, NULL, 0) FROM tbl") + } + } + } + + test("regexp_extract_all multiple matches") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("The prices are $10, $20, and $30", 1), + ("colors: red, green, blue", 1), + ("words: hello world", 1), + (null, 1)) + + withParquetTable(data, "tbl") { + // Extract all prices + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '\\$(\\d+)', 1) FROM tbl") + // Extract all words + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '([a-z]+)', 1) FROM tbl") + } + } + } + + test("regexp_extract_all with dictionary encoding") { + import org.apache.comet.CometConf + + withSQLConf( + CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true", + "parquet.enable.dictionary" -> "true") { + // Use repeated values to trigger dictionary encoding + val data = (0 until 1000).map(i => { + val text = if (i % 3 == 0) "a1b2c3" else if (i % 3 == 1) "x5y6" else "no-match" + (text, 1) + }) + + withParquetTable(data, "tbl") { + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '\\d+') FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '\\d+', 0) FROM tbl") + } + } + } + }