diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md index 60e2234f59..2cafe2a640 100644 --- a/docs/source/user-guide/latest/compatibility.md +++ b/docs/source/user-guide/latest/compatibility.md @@ -159,6 +159,9 @@ The following cast operations are generally compatible with Spark except for the | string | short | | | string | integer | | | string | long | | +| string | float | | +| string | double | | +| string | decimal | | | string | binary | | | string | date | Only supports years between 262143 BC and 262142 AD | | binary | string | | @@ -181,9 +184,6 @@ The following cast operations are not compatible with Spark for all inputs and a |-|-|-| | float | decimal | There can be rounding differences | | double | decimal | There can be rounding differences | -| string | float | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. | -| string | double | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. | -| string | decimal | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. 
Returns 0.0 instead of null if input contains no digits | | string | timestamp | Not all valid formats are supported | diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 12a147c6e1..41f66c143b 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -19,14 +19,9 @@ use crate::utils::array_with_timezone; use crate::{timezone, BinaryOutputStyle}; use crate::{EvalMode, SparkError, SparkResult}; use arrow::array::builder::StringBuilder; -use arrow::array::{ - BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray, StringArray, - StructArray, -}; +use arrow::array::{ArrayAccessor, BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray, PrimitiveBuilder, StringArray, StructArray}; use arrow::compute::can_cast_types; -use arrow::datatypes::{ - ArrowDictionaryKeyType, ArrowNativeType, DataType, GenericBinaryType, Schema, -}; +use arrow::datatypes::{i256, ArrowDictionaryKeyType, ArrowNativeType, DataType, Decimal256Type, DecimalType, GenericBinaryType, Schema}; use arrow::{ array::{ cast::AsArray, @@ -44,6 +39,7 @@ use arrow::{ record_batch::RecordBatch, util::display::FormatOptions, }; +use base64::prelude::*; use chrono::{DateTime, NaiveDate, TimeZone, Timelike}; use datafusion::common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result as DataFusionResult, @@ -56,6 +52,7 @@ use num::{ ToPrimitive, Zero, }; use regex::Regex; +use std::num::ParseFloatError; use std::str::FromStr; use std::{ any::Any, @@ -65,8 +62,6 @@ use std::{ sync::Arc, }; -use base64::prelude::*; - static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f"); const MICROS_PER_SECOND: i64 = 1000000; @@ -216,19 +211,9 @@ fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool use DataType::*; match to_type { Boolean | Int8 | Int16 | Int32 | Int64 | Binary => true, - Float32 | Float64 
=> {
-            // https://github.com/apache/datafusion-comet/issues/326
-            // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
-            // Does not support ANSI mode.
-            options.allow_incompat
-        }
-        Decimal128(_, _) => {
-            // https://github.com/apache/datafusion-comet/issues/325
-            // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
-            // Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits
-
-            options.allow_incompat
-        }
+        Float32 | Float64 => true,
+        Decimal128(_, _) => true,
+        Decimal256(_, _) => true,
         Date32 | Date64 => {
             // https://github.com/apache/datafusion-comet/issues/327
             // Only supports years between 262143 BC and 262142 AD
@@ -976,6 +961,13 @@ fn cast_array(
             cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
         }
         (Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
+        (Utf8, Float16 | Float32 | Float64) => cast_string_to_float(&array, to_type, eval_mode),
+        (Utf8 | LargeUtf8, Decimal128(precision, scale)) => {
+            cast_string_to_decimal(&array, to_type, precision, scale, eval_mode)
+        }
+        (Utf8 | LargeUtf8, Decimal256(precision, scale)) => {
+            cast_string_to_decimal(&array, to_type, precision, scale, eval_mode)
+        }
         (Int64, Int32)
         | (Int64, Int16)
         | (Int64, Int8)
@@ -1058,6 +1050,363 @@ fn cast_array(
     Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
 }
 
+fn cast_string_to_decimal(
+    array: &ArrayRef,
+    to_type: &DataType,
+    precision: &u8,
+    scale: &i8,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+    match to_type {
+        DataType::Decimal128(_, _) => {
+            cast_string_to_decimal128_impl(array, eval_mode, *precision, *scale)
+        }
+        DataType::Decimal256(_, _) => {
+            cast_string_to_decimal256_impl(array, eval_mode, *precision, *scale)
+        }
+        _ => Err(SparkError::Internal(format!(
+            "Unexpected type in cast_string_to_decimal: {:?}",
+            to_type
+        ))),
+    }
+}
+
+fn cast_string_to_decimal128_impl(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;
+
+    let mut decimal_builder = Decimal128Builder::with_capacity(string_array.len());
+
+    for i in 0..string_array.len() {
+        if string_array.is_null(i) {
+            decimal_builder.append_null();
+        } else {
+            let str_value = string_array.value(i).trim();
+            match parse_string_to_decimal(str_value, precision, scale) {
+                Ok(Some(decimal_value)) => {
+                    decimal_builder.append_value(decimal_value);
+                }
+                Ok(None) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(
+                            str_value,
+                            "STRING",
+                            &format!("DECIMAL({},{})", precision, scale),
+                        ));
+                    }
+                    decimal_builder.append_null();
+                }
+                Err(e) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(e);
+                    }
+                    decimal_builder.append_null();
+                }
+            }
+        }
+    }
+
+    Ok(Arc::new(
+        decimal_builder
+            .with_precision_and_scale(precision, scale)?
+            .finish(),
+    ))
+}
+
+fn cast_string_to_decimal256_impl(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;
+
+    let mut decimal_builder = PrimitiveBuilder::<Decimal256Type>::with_capacity(string_array.len());
+
+    for i in 0..string_array.len() {
+        if string_array.is_null(i) {
+            decimal_builder.append_null();
+        } else {
+            let str_value = string_array.value(i).trim();
+            match parse_string_to_decimal(str_value, precision, scale) {
+                Ok(Some(decimal_value)) => {
+                    // Convert i128 to i256
+                    let i256_value = i256::from_i128(decimal_value);
+                    decimal_builder.append_value(i256_value);
+                }
+                Ok(None) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(
+                            str_value,
+                            "STRING",
+                            &format!("DECIMAL({},{})", precision, scale),
+                        ));
+                    }
+                    decimal_builder.append_null();
+                }
+                Err(e) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(e);
+                    }
+                    decimal_builder.append_null();
+                }
+            }
+        }
+    }
+
+    Ok(Arc::new(
+        decimal_builder
+            .with_precision_and_scale(precision, scale)?
+            .finish(),
+    ))
+}
+
+/// Parse a string to decimal following Spark's behavior
+/// Returns Ok(Some(value)) if successful, Ok(None) if null, Err if invalid in ANSI mode
+fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
+    if s.is_empty() {
+        return Ok(None);
+    }
+
+    // Handle special values (inf, nan, etc.)
+    let s_lower = s.to_lowercase();
+    if s_lower == "inf"
+        || s_lower == "+inf"
+        || s_lower == "infinity"
+        || s_lower == "+infinity"
+        || s_lower == "-inf"
+        || s_lower == "-infinity"
+        || s_lower == "nan"
+    {
+        return Ok(None);
+    }
+
+    // Parse the string as a decimal number
+    // Note: We do NOT strip 'D' or 'F' suffixes - let parsing fail naturally
+    // This matches Spark's behavior which uses JavaBigDecimal(string)
+    match parse_decimal_str(s) {
+        Ok((mantissa, exponent)) => {
+            // Convert to target scale
+            let target_scale = scale as i32;
+            let scale_adjustment = target_scale - exponent;
+
+            let scaled_value = if scale_adjustment >= 0 {
+                // Need to multiply (increase scale)
+                mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
+            } else {
+                // Need to divide (decrease scale) - use rounding half up
+                let divisor = 10_i128.pow((-scale_adjustment) as u32);
+                let quotient = mantissa / divisor;
+                let remainder = mantissa % divisor;
+
+                // Round half up: if abs(remainder) >= divisor/2, round away from zero
+                let half_divisor = divisor / 2;
+                let rounded = if remainder.abs() >= half_divisor {
+                    if mantissa >= 0 {
+                        quotient + 1
+                    } else {
+                        quotient - 1
+                    }
+                } else {
+                    quotient
+                };
+                Some(rounded)
+            };
+
+            match scaled_value {
+                Some(value) => {
+                    // Check if it fits target precision
+                    if is_validate_decimal_precision(value, precision) {
+                        Ok(Some(value))
+                    } else {
+                        // Overflow
+                        Ok(None)
+                    }
+                }
+                None => {
+                    // Overflow during scaling
+                    Ok(None)
+                }
+            }
+        }
+        Err(_) => Ok(None),
+    }
+}
+
+/// Parse a decimal string into 
(mantissa, scale) +/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3) +fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> { + let s = s.trim(); + if s.is_empty() { + return Err("Empty string".to_string()); + } + + let negative = s.starts_with('-'); + let s = if negative || s.starts_with('+') { + &s[1..] + } else { + s + }; + + // Split by decimal point + let parts: Vec<&str> = s.split('.').collect(); + + if parts.len() > 2 { + return Err("Multiple decimal points".to_string()); + } + + let integral_part = parts[0]; + let fractional_part = if parts.len() == 2 { parts[1] } else { "" }; + + // Parse integral part + let integral_value: i128 = if integral_part.is_empty() { + 0 + } else { + integral_part + .parse() + .map_err(|_| "Invalid integral part".to_string())? + }; + + // Parse fractional part + let scale = fractional_part.len() as i32; + let fractional_value: i128 = if fractional_part.is_empty() { + 0 + } else { + fractional_part + .parse() + .map_err(|_| "Invalid fractional part".to_string())? 
+    };
+
+    // Combine: value = integral * 10^scale + fractional
+    let mantissa = integral_value
+        .checked_mul(10_i128.pow(scale as u32))
+        .and_then(|v| v.checked_add(fractional_value))
+        .ok_or("Overflow in mantissa calculation")?;
+
+    let final_mantissa = if negative { -mantissa } else { mantissa };
+
+    Ok((final_mantissa, scale))
+}
+
+fn cast_string_to_float(
+    array: &ArrayRef,
+    to_type: &DataType,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+    match to_type {
+        DataType::Float32 => cast_string_to_float_impl::<Float32Type>(array, eval_mode, "FLOAT"),
+        DataType::Float64 => cast_string_to_float_impl::<Float64Type>(array, eval_mode, "DOUBLE"),
+        _ => Err(SparkError::Internal(format!(
+            "Unsupported cast to float type: {:?}",
+            to_type
+        ))),
+    }
+}
+
+fn cast_string_to_float_impl<T: ArrowPrimitiveType>(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    type_name: &str,
+) -> SparkResult<ArrayRef>
+where
+    T::Native: FloatParse,
+{
+    let arr = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("could not parse input as string type".to_string()))?;
+
+    let mut cast_array = PrimitiveArray::<T>::builder(arr.len());
+
+    for i in 0..arr.len() {
+        if arr.is_null(i) {
+            cast_array.append_null();
+        } else {
+            let str_value = arr.value(i).trim();
+            match T::Native::parse_spark_float(str_value) {
+                Ok(v) => {
+                    cast_array.append_value(v);
+                }
+                Err(_) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(arr.value(i), "STRING", type_name));
+                    } else {
+                        cast_array.append_null();
+                    }
+                }
+            }
+        }
+    }
+    Ok(Arc::new(cast_array.finish()))
+}
+
+/// Trait for parsing float from str
+trait FloatParse: Sized {
+    fn parse_spark_float(s: &str) -> Result<Self, ParseFloatError>;
+}
+
+impl FloatParse for f32 {
+    fn parse_spark_float(s: &str) -> Result<Self, ParseFloatError> {
+        let s_lower = s.to_lowercase();
+
+        if s_lower == "inf" || s_lower == "+inf" || s_lower == "infinity" || s_lower == "+infinity"
+        {
+            return Ok(f32::INFINITY);
+        }
+
+        if s_lower == "-inf" || s_lower == "-infinity" {
+            return Ok(f32::NEG_INFINITY);
+        }
+
+        if s_lower == "nan" {
+            return Ok(f32::NAN);
+        }
+
+        let pruned = if s_lower.ends_with('d') || s_lower.ends_with('f') {
+            &s[..s.len() - 1]
+        } else {
+            s
+        };
+        pruned.parse::<f32>()
+    }
+}
+
+impl FloatParse for f64 {
+    fn parse_spark_float(s: &str) -> Result<Self, ParseFloatError> {
+        let s_lower = s.to_lowercase();
+
+        if s_lower == "inf" || s_lower == "+inf" || s_lower == "infinity" || s_lower == "+infinity"
+        {
+            return Ok(f64::INFINITY);
+        }
+
+        if s_lower == "-inf" || s_lower == "-infinity" {
+            return Ok(f64::NEG_INFINITY);
+        }
+
+        if s_lower == "nan" {
+            return Ok(f64::NAN);
+        }
+
+        let cleaned = if s_lower.ends_with('d') || s_lower.ends_with('f') {
+            &s[..s.len() - 1]
+        } else {
+            s
+        };
+        cleaned.parse::<f64>()
+    }
+}
+
 fn cast_binary_to_string(
     array: &dyn Array,
     spark_cast_options: &SparkCastOptions,
@@ -1185,11 +1534,13 @@ fn is_datafusion_spark_compatible(
             | DataType::Decimal256(_, _)
             | DataType::Utf8 // note that there can be formatting differences
         ),
-        DataType::Utf8 if allow_incompat => matches!(
+        DataType::Utf8 if allow_incompat => {
+            matches!(to_type, DataType::Binary | DataType::Decimal128(_, _))
+        }
+        DataType::Utf8 => matches!(
             to_type,
-            DataType::Binary | DataType::Float32 | DataType::Float64 | DataType::Decimal128(_, _)
+            DataType::Binary | DataType::Float32 | DataType::Float64
         ),
-        DataType::Utf8 => matches!(to_type, DataType::Binary),
         DataType::Date32 => matches!(to_type, DataType::Utf8),
         DataType::Timestamp(_, _) => {
             matches!(
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 98ce8ac44d..4b16242305 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -185,16 +185,9 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
       case DataTypes.BinaryType =>
         Compatible()
       case DataTypes.FloatType | DataTypes.DoubleType =>
-        // https://github.com/apache/datafusion-comet/issues/326
-        
Incompatible( - Some( - "Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " + - "Does not support ANSI mode.")) + Compatible() case _: DecimalType => - // https://github.com/apache/datafusion-comet/issues/325 - Incompatible( - Some("Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " + - "Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits")) + Compatible() case DataTypes.DateType => // https://github.com/apache/datafusion-comet/issues/327 Compatible(Some("Only supports years between 262143 BC and 262142 AD")) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 1912e982b9..af5f4d82fe 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -652,35 +652,42 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.LongType) } - ignore("cast StringType to FloatType") { - // https://github.com/apache/datafusion-comet/issues/326 - castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.FloatType) - } + def specialValues: Seq[String] = Seq( + "1.5f", + "1.5F", + "2.0d", + "2.0D", + "3.14159265358979d", + "inf", + "Inf", + "INF", + "+inf", + "+Infinity", + "-inf", + "-Infinity", + "NaN", + "nan", + "NAN", + "1.23e4", + "1.23E4", + "-1.23e-4", + " 123.456789 ", + "0.0", + "-0.0", + "", + "xyz", + null) - test("cast StringType to FloatType (partial support)") { - withSQLConf( - CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true", - SQLConf.ANSI_ENABLED.key -> "false") { - castTest( - gen.generateStrings(dataSize, "0123456789.", 8).toDF("a"), - DataTypes.FloatType, - testAnsi = false) + test("cast StringType to FloatType") { + Seq(true, false).foreach { v => + castTest(specialValues.toDF("a"), 
DataTypes.FloatType, testAnsi = v)
     }
-  }
-  ignore("cast StringType to DoubleType") {
-    // https://github.com/apache/datafusion-comet/issues/326
-    castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.DoubleType)
   }
-  test("cast StringType to DoubleType (partial support)") {
-    withSQLConf(
-      CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
-      SQLConf.ANSI_ENABLED.key -> "false") {
-      castTest(
-        gen.generateStrings(dataSize, "0123456789.", 8).toDF("a"),
-        DataTypes.DoubleType,
-        testAnsi = false)
+  test("cast StringType to DoubleType") {
+    Seq(true, false).foreach { v =>
+      castTest(specialValues.toDF("a"), DataTypes.DoubleType, testAnsi = v)
     }
   }
 
@@ -690,6 +697,41 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(values, DataTypes.createDecimalType(10, 2))
   }
 
+  test("cast StringType to DecimalType(10,2) basic values") {
+    val values = Seq(
+      "123.45",
+      "-67.89",
+      "0.001",
+      "999.99",
+      "123.456",
+      "123.45D",
+      ".5",
+      "5.",
+      "+123.45",
+      " 123.45 ",
+      "inf",
+      "",
+      "abc",
+      null).toDF("a")
+    castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = false)
+  }
+
+  test("cast StringType to DecimalType(38,10) high precision") {
+    val values = Seq(
+      "123.45",
+      "-67.89",
+      "9999999999999999999999999999.9999999999",
+      "-9999999999999999999999999999.9999999999",
+      "0.0000000001",
+      "123456789012345678.1234567890",
+      "123.456",
+      "inf",
+      "",
+      "abc",
+      null).toDF("a")
+    castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = false)
+  }
+
   test("cast StringType to DecimalType(10,2) (partial support)") {
     withSQLConf(
       CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",