@@ -20,16 +20,16 @@ use std::sync::Arc;
2020
2121use arrow:: array:: ArrayRef ;
2222use arrow:: compute;
23- use arrow:: datatypes:: { DataType , Date32Type } ;
23+ use arrow:: datatypes:: { DataType , Date32Type , Field , FieldRef } ;
2424use arrow:: error:: ArrowError ;
2525use datafusion_common:: cast:: {
2626 as_date32_array, as_int16_array, as_int32_array, as_int8_array,
2727} ;
2828use datafusion_common:: utils:: take_function_args;
2929use datafusion_common:: { internal_err, Result } ;
3030use datafusion_expr:: {
31- ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature , TypeSignature ,
32- Volatility ,
31+ ColumnarValue , ReturnFieldArgs , ScalarFunctionArgs , ScalarUDFImpl , Signature ,
32+ TypeSignature , Volatility ,
3333} ;
3434use datafusion_functions:: utils:: make_scalar_function;
3535
@@ -79,7 +79,21 @@ impl ScalarUDFImpl for SparkDateAdd {
7979 }
8080
8181 fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
82- Ok ( DataType :: Date32 )
82+ internal_err ! ( "Use return_field_from_args in this case instead." )
83+ }
84+
85+ fn return_field_from_args ( & self , args : ReturnFieldArgs ) -> Result < FieldRef > {
86+ let nullable = args. arg_fields . iter ( ) . any ( |f| f. is_nullable ( ) )
87+ || args
88+ . scalar_arguments
89+ . iter ( )
90+ . any ( |arg| matches ! ( arg, Some ( sv) if sv. is_null( ) ) ) ;
91+
92+ Ok ( Arc :: new ( Field :: new (
93+ self . name ( ) ,
94+ DataType :: Date32 ,
95+ nullable,
96+ ) ) )
8397 }
8498
8599 fn invoke_with_args ( & self , args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
@@ -136,3 +150,69 @@ fn spark_date_add(args: &[ArrayRef]) -> Result<ArrayRef> {
136150 } ;
137151 Ok ( Arc :: new ( result) )
138152}
153+
154+ #[ cfg( test) ]
155+ mod tests {
156+ use super :: * ;
157+ use arrow:: datatypes:: Field ;
158+ use datafusion_common:: ScalarValue ;
159+
160+ #[ test]
161+ fn test_date_add_non_nullable_inputs ( ) {
162+ let func = SparkDateAdd :: new ( ) ;
163+ let args = & [
164+ Arc :: new ( Field :: new ( "date" , DataType :: Date32 , false ) ) ,
165+ Arc :: new ( Field :: new ( "num" , DataType :: Int8 , false ) ) ,
166+ ] ;
167+
168+ let ret_field = func
169+ . return_field_from_args ( ReturnFieldArgs {
170+ arg_fields : args,
171+ scalar_arguments : & [ None , None ] ,
172+ } )
173+ . unwrap ( ) ;
174+
175+ assert_eq ! ( ret_field. data_type( ) , & DataType :: Date32 ) ;
176+ assert ! ( !ret_field. is_nullable( ) ) ;
177+ }
178+
179+ #[ test]
180+ fn test_date_add_nullable_inputs ( ) {
181+ let func = SparkDateAdd :: new ( ) ;
182+ let args = & [
183+ Arc :: new ( Field :: new ( "date" , DataType :: Date32 , false ) ) ,
184+ Arc :: new ( Field :: new ( "num" , DataType :: Int16 , true ) ) ,
185+ ] ;
186+
187+ let ret_field = func
188+ . return_field_from_args ( ReturnFieldArgs {
189+ arg_fields : args,
190+ scalar_arguments : & [ None , None ] ,
191+ } )
192+ . unwrap ( ) ;
193+
194+ assert_eq ! ( ret_field. data_type( ) , & DataType :: Date32 ) ;
195+ assert ! ( ret_field. is_nullable( ) ) ;
196+ }
197+
198+ #[ test]
199+ fn test_date_add_null_scalar ( ) {
200+ let func = SparkDateAdd :: new ( ) ;
201+ let args = & [
202+ Arc :: new ( Field :: new ( "date" , DataType :: Date32 , false ) ) ,
203+ Arc :: new ( Field :: new ( "num" , DataType :: Int32 , false ) ) ,
204+ ] ;
205+
206+ let null_scalar = ScalarValue :: Int32 ( None ) ;
207+
208+ let ret_field = func
209+ . return_field_from_args ( ReturnFieldArgs {
210+ arg_fields : args,
211+ scalar_arguments : & [ None , Some ( & null_scalar) ] ,
212+ } )
213+ . unwrap ( ) ;
214+
215+ assert_eq ! ( ret_field. data_type( ) , & DataType :: Date32 ) ;
216+ assert ! ( ret_field. is_nullable( ) ) ;
217+ }
218+ }
0 commit comments