Skip to content

Commit 42140ae

Browse files
committed
fix: Added nullable return from data_add(apache#19151)
1 parent 4fb36b2 commit 42140ae

File tree

1 file changed

+57
-4
lines changed

1 file changed

+57
-4
lines changed

datafusion/spark/src/function/datetime/date_add.rs

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,16 @@ use std::sync::Arc;
2020

2121
use arrow::array::ArrayRef;
2222
use arrow::compute;
23-
use arrow::datatypes::{DataType, Date32Type};
23+
use arrow::datatypes::{DataType, Date32Type, Field, FieldRef};
2424
use arrow::error::ArrowError;
2525
use datafusion_common::cast::{
2626
as_date32_array, as_int16_array, as_int32_array, as_int8_array,
2727
};
2828
use datafusion_common::utils::take_function_args;
2929
use datafusion_common::{internal_err, Result};
3030
use datafusion_expr::{
31-
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
32-
Volatility,
31+
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
32+
TypeSignature, Volatility,
3333
};
3434
use datafusion_functions::utils::make_scalar_function;
3535

@@ -79,7 +79,16 @@ impl ScalarUDFImpl for SparkDateAdd {
7979
}
8080

8181
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
82-
Ok(DataType::Date32)
82+
internal_err!("Use return_field_from_args in this case instead.")
83+
}
84+
85+
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
86+
let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
87+
Ok(Arc::new(Field::new(
88+
self.name(),
89+
DataType::Date32,
90+
nullable,
91+
)))
8392
}
8493

8594
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
@@ -136,3 +145,47 @@ fn spark_date_add(args: &[ArrayRef]) -> Result<ArrayRef> {
136145
};
137146
Ok(Arc::new(result))
138147
}
148+
149+
#[cfg(test)]
150+
mod tests {
151+
use super::*;
152+
use arrow::datatypes::Field;
153+
154+
#[test]
155+
fn test_date_add_non_nullable_inputs() {
156+
let func = SparkDateAdd::new();
157+
let args = &[
158+
Arc::new(Field::new("arg_date32", DataType::Date32, false)),
159+
Arc::new(Field::new("arg_int16", DataType::Int16, false)),
160+
];
161+
162+
let ret_field = func
163+
.return_field_from_args(ReturnFieldArgs {
164+
arg_fields: args,
165+
scalar_arguments: &[None, None],
166+
})
167+
.unwrap();
168+
169+
assert_eq!(ret_field.data_type(), &DataType::Date32);
170+
assert!(!ret_field.is_nullable());
171+
}
172+
173+
#[test]
174+
fn test_date_add_nullable_inputs() {
175+
let func = SparkDateAdd::new();
176+
let args = &[
177+
Arc::new(Field::new("arg_date32", DataType::Date32, false)),
178+
Arc::new(Field::new("arg_int16", DataType::Int16, true)),
179+
];
180+
181+
let ret_field = func
182+
.return_field_from_args(ReturnFieldArgs {
183+
arg_fields: args,
184+
scalar_arguments: &[None, None],
185+
})
186+
.unwrap();
187+
188+
assert_eq!(ret_field.data_type(), &DataType::Date32);
189+
assert!(ret_field.is_nullable());
190+
}
191+
}

0 commit comments

Comments
 (0)