Skip to content

Commit f0dcc89

Browse files
committed
fix: Added nullable return from data_add(apache#19151)
1 parent 4fb36b2 commit f0dcc89

File tree

1 file changed

+84
-4
lines changed

1 file changed

+84
-4
lines changed

datafusion/spark/src/function/datetime/date_add.rs

Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,16 @@ use std::sync::Arc;
2020

2121
use arrow::array::ArrayRef;
2222
use arrow::compute;
23-
use arrow::datatypes::{DataType, Date32Type};
23+
use arrow::datatypes::{DataType, Date32Type, Field, FieldRef};
2424
use arrow::error::ArrowError;
2525
use datafusion_common::cast::{
2626
as_date32_array, as_int16_array, as_int32_array, as_int8_array,
2727
};
2828
use datafusion_common::utils::take_function_args;
2929
use datafusion_common::{internal_err, Result};
3030
use datafusion_expr::{
31-
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
32-
Volatility,
31+
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
32+
TypeSignature, Volatility,
3333
};
3434
use datafusion_functions::utils::make_scalar_function;
3535

@@ -79,7 +79,21 @@ impl ScalarUDFImpl for SparkDateAdd {
7979
}
8080

8181
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
82-
Ok(DataType::Date32)
82+
internal_err!("Use return_field_from_args in this case instead.")
83+
}
84+
85+
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
86+
let nullable = args.arg_fields.iter().any(|f| f.is_nullable())
87+
|| args
88+
.scalar_arguments
89+
.iter()
90+
.any(|arg| matches!(arg, Some(sv) if sv.is_null()));
91+
92+
Ok(Arc::new(Field::new(
93+
self.name(),
94+
DataType::Date32,
95+
nullable,
96+
)))
8397
}
8498

8599
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
@@ -136,3 +150,69 @@ fn spark_date_add(args: &[ArrayRef]) -> Result<ArrayRef> {
136150
};
137151
Ok(Arc::new(result))
138152
}
153+
154+
#[cfg(test)]
155+
mod tests {
156+
use super::*;
157+
use arrow::datatypes::Field;
158+
use datafusion_common::ScalarValue;
159+
160+
#[test]
161+
fn test_date_add_non_nullable_inputs() {
162+
let func = SparkDateAdd::new();
163+
let args = &[
164+
Arc::new(Field::new("date", DataType::Date32, false)),
165+
Arc::new(Field::new("num", DataType::Int8, false)),
166+
];
167+
168+
let ret_field = func
169+
.return_field_from_args(ReturnFieldArgs {
170+
arg_fields: args,
171+
scalar_arguments: &[None, None],
172+
})
173+
.unwrap();
174+
175+
assert_eq!(ret_field.data_type(), &DataType::Date32);
176+
assert!(!ret_field.is_nullable());
177+
}
178+
179+
#[test]
180+
fn test_date_add_nullable_inputs() {
181+
let func = SparkDateAdd::new();
182+
let args = &[
183+
Arc::new(Field::new("date", DataType::Date32, false)),
184+
Arc::new(Field::new("num", DataType::Int16, true)),
185+
];
186+
187+
let ret_field = func
188+
.return_field_from_args(ReturnFieldArgs {
189+
arg_fields: args,
190+
scalar_arguments: &[None, None],
191+
})
192+
.unwrap();
193+
194+
assert_eq!(ret_field.data_type(), &DataType::Date32);
195+
assert!(ret_field.is_nullable());
196+
}
197+
198+
#[test]
199+
fn test_date_add_null_scalar() {
200+
let func = SparkDateAdd::new();
201+
let args = &[
202+
Arc::new(Field::new("date", DataType::Date32, false)),
203+
Arc::new(Field::new("num", DataType::Int32, false)),
204+
];
205+
206+
let null_scalar = ScalarValue::Int32(None);
207+
208+
let ret_field = func
209+
.return_field_from_args(ReturnFieldArgs {
210+
arg_fields: args,
211+
scalar_arguments: &[None, Some(&null_scalar)],
212+
})
213+
.unwrap();
214+
215+
assert_eq!(ret_field.data_type(), &DataType::Date32);
216+
assert!(ret_field.is_nullable());
217+
}
218+
}

0 commit comments

Comments
 (0)