@@ -20,16 +20,16 @@ use std::sync::Arc;
2020
2121use arrow:: array:: ArrayRef ;
2222use arrow:: compute;
23- use arrow:: datatypes:: { DataType , Date32Type } ;
23+ use arrow:: datatypes:: { DataType , Date32Type , Field , FieldRef } ;
2424use arrow:: error:: ArrowError ;
2525use datafusion_common:: cast:: {
2626 as_date32_array, as_int16_array, as_int32_array, as_int8_array,
2727} ;
2828use datafusion_common:: utils:: take_function_args;
2929use datafusion_common:: { internal_err, Result } ;
3030use datafusion_expr:: {
31- ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature , TypeSignature ,
32- Volatility ,
31+ ColumnarValue , ReturnFieldArgs , ScalarFunctionArgs , ScalarUDFImpl , Signature ,
32+ TypeSignature , Volatility ,
3333} ;
3434use datafusion_functions:: utils:: make_scalar_function;
3535
@@ -79,7 +79,16 @@ impl ScalarUDFImpl for SparkDateAdd {
7979 }
8080
8181 fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
82- Ok ( DataType :: Date32 )
82+ internal_err ! ( "Use return_field_from_args in this case instead." )
83+ }
84+
85+ fn return_field_from_args ( & self , args : ReturnFieldArgs ) -> Result < FieldRef > {
86+ let nullable = args. arg_fields . iter ( ) . any ( |f| f. is_nullable ( ) ) ;
87+ Ok ( Arc :: new ( Field :: new (
88+ self . name ( ) ,
89+ DataType :: Date32 ,
90+ nullable,
91+ ) ) )
8392 }
8493
8594 fn invoke_with_args ( & self , args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
@@ -136,3 +145,47 @@ fn spark_date_add(args: &[ArrayRef]) -> Result<ArrayRef> {
136145 } ;
137146 Ok ( Arc :: new ( result) )
138147}
148+
149+ #[ cfg( test) ]
150+ mod tests {
151+ use super :: * ;
152+ use arrow:: datatypes:: Field ;
153+
154+ #[ test]
155+ fn test_date_add_non_nullable_inputs ( ) {
156+ let func = SparkDateAdd :: new ( ) ;
157+ let args = & [
158+ Arc :: new ( Field :: new ( "arg_date32" , DataType :: Date32 , false ) ) ,
159+ Arc :: new ( Field :: new ( "arg_int16" , DataType :: Int16 , false ) ) ,
160+ ] ;
161+
162+ let ret_field = func
163+ . return_field_from_args ( ReturnFieldArgs {
164+ arg_fields : args,
165+ scalar_arguments : & [ None , None ] ,
166+ } )
167+ . unwrap ( ) ;
168+
169+ assert_eq ! ( ret_field. data_type( ) , & DataType :: Date32 ) ;
170+ assert ! ( !ret_field. is_nullable( ) ) ;
171+ }
172+
173+ #[ test]
174+ fn test_date_add_nullable_inputs ( ) {
175+ let func = SparkDateAdd :: new ( ) ;
176+ let args = & [
177+ Arc :: new ( Field :: new ( "arg_date32" , DataType :: Date32 , false ) ) ,
178+ Arc :: new ( Field :: new ( "arg_int16" , DataType :: Int16 , true ) ) ,
179+ ] ;
180+
181+ let ret_field = func
182+ . return_field_from_args ( ReturnFieldArgs {
183+ arg_fields : args,
184+ scalar_arguments : & [ None , None ] ,
185+ } )
186+ . unwrap ( ) ;
187+
188+ assert_eq ! ( ret_field. data_type( ) , & DataType :: Date32 ) ;
189+ assert ! ( ret_field. is_nullable( ) ) ;
190+ }
191+ }
0 commit comments