Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 23 additions & 33 deletions datafusion/spark/src/function/hash/crc32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@ use arrow::array::{ArrayRef, Int64Array};
use arrow::datatypes::DataType;
use crc32fast::Hasher;
use datafusion_common::cast::{
as_binary_array, as_binary_view_array, as_large_binary_array,
as_binary_array, as_binary_view_array, as_fixed_size_binary_array,
as_large_binary_array,
};
use datafusion_common::{exec_err, internal_err, Result};
use datafusion_common::types::{logical_string, NativeType};
use datafusion_common::utils::take_function_args;
use datafusion_common::{internal_err, Result};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature,
TypeSignatureClass, Volatility,
};
use datafusion_functions::utils::make_scalar_function;

Expand All @@ -45,7 +49,14 @@ impl Default for SparkCrc32 {
impl SparkCrc32 {
pub fn new() -> Self {
Self {
signature: Signature::user_defined(Volatility::Immutable),
signature: Signature::coercible(
vec![Coercion::new_implicit(
TypeSignatureClass::Binary,
vec![TypeSignatureClass::Native(logical_string())],
NativeType::Binary,
)],
Volatility::Immutable,
),
}
}
}
Expand All @@ -70,24 +81,6 @@ impl ScalarUDFImpl for SparkCrc32 {
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(spark_crc32, vec![])(&args.args)
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 1 {
return exec_err!(
"`crc32` function requires 1 argument, got {}",
arg_types.len()
);
}
match arg_types[0] {
DataType::Binary | DataType::LargeBinary | DataType::BinaryView => {
Ok(vec![arg_types[0].clone()])
}
DataType::Utf8 | DataType::Utf8View => Ok(vec![DataType::Binary]),
DataType::LargeUtf8 => Ok(vec![DataType::LargeBinary]),
DataType::Null => Ok(vec![DataType::Binary]),
_ => exec_err!("`crc32` function does not support type {}", arg_types[0]),
}
}
}

fn spark_crc32_digest(value: &[u8]) -> i64 {
Expand All @@ -104,14 +97,10 @@ fn spark_crc32_impl<'a>(input: impl Iterator<Item = Option<&'a [u8]>>) -> ArrayR
}

fn spark_crc32(args: &[ArrayRef]) -> Result<ArrayRef> {
let [input] = args else {
return internal_err!(
"Spark `crc32` function requires 1 argument, got {}",
args.len()
);
};
let [input] = take_function_args("crc32", args)?;

match input.data_type() {
DataType::Null => Ok(Arc::new(Int64Array::new_null(input.len()))),
DataType::Binary => {
let input = as_binary_array(input)?;
Ok(spark_crc32_impl(input.iter()))
Expand All @@ -124,11 +113,12 @@ fn spark_crc32(args: &[ArrayRef]) -> Result<ArrayRef> {
let input = as_binary_view_array(input)?;
Ok(spark_crc32_impl(input.iter()))
}
_ => {
exec_err!(
"Spark `crc32` function: argument must be binary or large binary, got {:?}",
input.data_type()
)
DataType::FixedSizeBinary(_) => {
let input = as_fixed_size_binary_array(input)?;
Ok(spark_crc32_impl(input.iter()))
}
dt => {
internal_err!("Unsupported data type for crc32: {dt}")
}
}
}
57 changes: 23 additions & 34 deletions datafusion/spark/src/function/hash/sha1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@ use std::sync::Arc;
use arrow::array::{ArrayRef, StringArray};
use arrow::datatypes::DataType;
use datafusion_common::cast::{
as_binary_array, as_binary_view_array, as_large_binary_array,
as_binary_array, as_binary_view_array, as_fixed_size_binary_array,
as_large_binary_array,
};
use datafusion_common::{exec_err, internal_err, Result};
use datafusion_common::types::{logical_string, NativeType};
use datafusion_common::utils::take_function_args;
use datafusion_common::{internal_err, Result};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature,
TypeSignatureClass, Volatility,
};
use datafusion_functions::utils::make_scalar_function;
use sha1::{Digest, Sha1};
Expand All @@ -47,7 +51,14 @@ impl Default for SparkSha1 {
impl SparkSha1 {
pub fn new() -> Self {
Self {
signature: Signature::user_defined(Volatility::Immutable),
signature: Signature::coercible(
vec![Coercion::new_implicit(
TypeSignatureClass::Binary,
vec![TypeSignatureClass::Native(logical_string())],
NativeType::Binary,
)],
Volatility::Immutable,
),
aliases: vec!["sha".to_string()],
}
}
Expand Down Expand Up @@ -77,32 +88,13 @@ impl ScalarUDFImpl for SparkSha1 {
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(spark_sha1, vec![])(&args.args)
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 1 {
return exec_err!(
"`sha1` function requires 1 argument, got {}",
arg_types.len()
);
}
match arg_types[0] {
DataType::Binary | DataType::LargeBinary | DataType::BinaryView => {
Ok(vec![arg_types[0].clone()])
}
DataType::Utf8 | DataType::Utf8View => Ok(vec![DataType::Binary]),
DataType::LargeUtf8 => Ok(vec![DataType::LargeBinary]),
DataType::Null => Ok(vec![DataType::Binary]),
_ => exec_err!("`sha1` function does not support type {}", arg_types[0]),
}
}
}

fn spark_sha1_digest(value: &[u8]) -> String {
let result = Sha1::digest(value);
let mut s = String::with_capacity(result.len() * 2);
#[allow(deprecated)]
for b in result.as_slice() {
#[allow(clippy::unwrap_used)]
write!(&mut s, "{b:02x}").unwrap();
}
s
Expand All @@ -116,14 +108,10 @@ fn spark_sha1_impl<'a>(input: impl Iterator<Item = Option<&'a [u8]>>) -> ArrayRe
}

fn spark_sha1(args: &[ArrayRef]) -> Result<ArrayRef> {
let [input] = args else {
return internal_err!(
"Spark `sha1` function requires 1 argument, got {}",
args.len()
);
};
let [input] = take_function_args("sha1", args)?;

match input.data_type() {
DataType::Null => Ok(Arc::new(StringArray::new_null(input.len()))),
DataType::Binary => {
let input = as_binary_array(input)?;
Ok(spark_sha1_impl(input.iter()))
Expand All @@ -136,11 +124,12 @@ fn spark_sha1(args: &[ArrayRef]) -> Result<ArrayRef> {
let input = as_binary_view_array(input)?;
Ok(spark_sha1_impl(input.iter()))
}
_ => {
exec_err!(
"Spark `sha1` function: argument must be binary or large binary, got {:?}",
input.data_type()
)
DataType::FixedSizeBinary(_) => {
let input = as_fixed_size_binary_array(input)?;
Ok(spark_sha1_impl(input.iter()))
}
dt => {
internal_err!("Unsupported data type for sha1: {dt}")
}
}
}
26 changes: 16 additions & 10 deletions datafusion/sqllogictest/test_files/spark/hash/crc32.slt
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@ SELECT crc32('Spark');
----
1557323817

query I
SELECT crc32(NULL);
----
NULL

query I
SELECT crc32('');
----
Expand All @@ -45,31 +40,42 @@ SELECT crc32(arrow_cast('', 'Binary'));
----
0

# Test with LargeUtf8 (using CAST to ensure type)
# Test with different types
query I
SELECT crc32(NULL);
----
NULL

query I
SELECT crc32(arrow_cast('Spark', 'LargeUtf8'));
----
1557323817

# Test with Utf8View (using CAST to ensure type)
query I
SELECT crc32(arrow_cast('Spark', 'Utf8View'));
----
1557323817

# Test with different binary types
query I
SELECT crc32(arrow_cast('Spark', 'Utf8'));
----
1557323817

query I
SELECT crc32(arrow_cast('Spark', 'Binary'));
----
1557323817

# Test with LargeBinary
query I
SELECT crc32(arrow_cast(arrow_cast('Spark', 'Binary'), 'FixedSizeBinary(5)'));
----
1557323817

query I
SELECT crc32(arrow_cast('Spark', 'LargeBinary'));
----
1557323817

# Test with BinaryView
query I
SELECT crc32(arrow_cast('Spark', 'BinaryView'));
----
Expand Down
25 changes: 18 additions & 7 deletions datafusion/sqllogictest/test_files/spark/hash/sha1.slt
Original file line number Diff line number Diff line change
Expand Up @@ -31,40 +31,51 @@ SELECT sha1('Spark');
85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c

query T
SELECT sha1(NULL);
SELECT sha('Spark');
----
NULL
85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c

query T
SELECT sha1('');
----
da39a3ee5e6b4b0d3255bfef95601890afd80709

# Test with LargeUtf8 (using CAST to ensure type)
# Test with different types
query T
SELECT sha1(NULL);
----
NULL

query T
SELECT sha1(arrow_cast('Spark', 'LargeUtf8'));
----
85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c

# Test with Utf8View (using CAST to ensure type)
query T
SELECT sha1(arrow_cast('Spark', 'Utf8View'));
----
85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c

# Test with Binary
query T
SELECT sha1(arrow_cast('Spark', 'Utf8'));
----
85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c

query T
SELECT sha1(arrow_cast('Spark', 'Binary'));
----
85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c

# Test with LargeBinary
query T
SELECT sha1(arrow_cast(arrow_cast('Spark', 'Binary'), 'FixedSizeBinary(5)'));
----
85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c

query T
SELECT sha1(arrow_cast('Spark', 'LargeBinary'));
----
85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c

# Test with BinaryView
query T
SELECT sha1(arrow_cast('Spark', 'BinaryView'));
----
Expand Down
Loading