@@ -343,6 +343,70 @@ impl StructArray {
343343 fields,
344344 }
345345 }
346+
347+ /// Returns the children of this [`StructArray`] with the struct's validity
348+ /// bitmap AND'd into each child's validity bitmap.
349+ ///
350+ /// This ensures that positions where the struct itself is null are also
351+ /// null in each returned child array. Fields that were non-nullable are
352+ /// marked nullable in the returned [`Fields`] when the struct has nulls.
353+ ///
354+ /// If the struct has no nulls, children and fields are returned as-is.
355+ ///
356+ /// This mirrors the semantics of C++ Arrow's `StructArray::Flatten`.
357+ ///
358+ /// # Example
359+ ///
360+ /// ```
361+ /// # use std::sync::Arc;
362+ /// # use arrow_array::{Array, ArrayRef, Int32Array, StructArray};
363+ /// # use arrow_buffer::{BooleanBuffer, NullBuffer};
364+ /// # use arrow_schema::{DataType, Field, Fields};
365+ /// let child = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
366+ /// let struct_nulls = NullBuffer::new(BooleanBuffer::from(vec![true, false, true]));
367+ /// let sa = StructArray::new(
368+ /// Fields::from(vec![Field::new("a", DataType::Int32, false)]),
369+ /// vec![child],
370+ /// Some(struct_nulls),
371+ /// );
372+ /// let (fields, columns) = sa.flatten();
373+ /// assert!(fields[0].is_nullable());
374+ /// assert!(columns[0].is_null(1));
375+ /// ```
376+ pub fn flatten ( & self ) -> ( Fields , Vec < ArrayRef > ) {
377+ let schema_fields = self . fields ( ) ;
378+
379+ let struct_nulls = match & self . nulls {
380+ Some ( n) => n,
381+ None => return ( schema_fields. clone ( ) , self . fields . clone ( ) ) ,
382+ } ;
383+
384+ let new_fields: Fields = schema_fields
385+ . iter ( )
386+ . map ( |f| {
387+ if f. is_nullable ( ) {
388+ Arc :: clone ( f)
389+ } else {
390+ Arc :: new ( f. as_ref ( ) . clone ( ) . with_nullable ( true ) )
391+ }
392+ } )
393+ . collect :: < Vec < _ > > ( )
394+ . into ( ) ;
395+
396+ let new_columns = self
397+ . fields
398+ . iter ( )
399+ . map ( |child| {
400+ let merged = NullBuffer :: union ( Some ( struct_nulls) , child. nulls ( ) ) ;
401+ // SAFETY: We only make the null buffer more restrictive (adding nulls).
402+ // All data buffers and child data remain unchanged.
403+ let data = child. to_data ( ) . into_builder ( ) . nulls ( merged) ;
404+ make_array ( unsafe { data. build_unchecked ( ) } )
405+ } )
406+ . collect ( ) ;
407+
408+ ( new_fields, new_columns)
409+ }
346410}
347411
348412impl From < ArrayData > for StructArray {
@@ -958,4 +1022,140 @@ mod tests {
9581022
9591023 StructArray :: try_new ( fields, arrays, nulls) . expect ( "should not error" ) ;
9601024 }
1025+
#[test]
fn test_flatten_no_nulls() {
    // A struct built without a null buffer: the field must stay non-nullable
    // and the flattened column must not gain any nulls.
    let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let field = Arc::new(Field::new("a", DataType::Int32, false));
    let struct_array = StructArray::from(vec![(field, values)]);

    let (out_fields, out_columns) = struct_array.flatten();

    assert_eq!(out_columns.len(), 1);
    assert!(!out_fields[0].is_nullable());

    let column = &out_columns[0];
    assert_eq!(column.null_count(), 0);
    assert_eq!(column.len(), 3);
}
1041+
#[test]
fn test_flatten_struct_nulls_child_no_nulls() {
    // The child has no nulls of its own; the struct is null at index 1.
    let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let validity = BooleanBuffer::from(vec![true, false, true]);
    let struct_array = StructArray::new(
        Fields::from(vec![Field::new("a", DataType::Int32, false)]),
        vec![values],
        Some(NullBuffer::new(validity)),
    );

    let (out_fields, out_columns) = struct_array.flatten();

    assert!(out_fields[0].is_nullable());

    // The struct-level null must show up in the flattened child.
    let column = &out_columns[0];
    assert!(column.is_valid(0));
    assert!(column.is_null(1));
    assert!(column.is_valid(2));
    assert_eq!(column.null_count(), 1);
}
1060+
#[test]
fn test_flatten_both_have_nulls() {
    // index:            0      1      2      3
    // struct validity:  valid  null   valid  valid
    // child validity:   valid  valid  null   valid
    // flattened:        valid  null   null   valid
    let values: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)]));
    let validity = BooleanBuffer::from(vec![true, false, true, true]);
    let struct_array = StructArray::new(
        Fields::from(vec![Field::new("a", DataType::Int32, true)]),
        vec![values],
        Some(NullBuffer::new(validity)),
    );

    let (out_fields, out_columns) = struct_array.flatten();

    assert!(out_fields[0].is_nullable());

    let column = &out_columns[0];
    assert!(column.is_valid(0));
    assert!(column.is_null(1));
    assert!(column.is_null(2));
    assert!(column.is_valid(3));
    assert_eq!(column.null_count(), 2);
}
1083+
#[test]
fn test_flatten_sliced_struct() {
    // Struct validity is [valid, null, valid, null]; slicing (1, 2) keeps
    // the middle two rows, i.e. [null, valid].
    let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
    let validity = BooleanBuffer::from(vec![true, false, true, false]);
    let window = StructArray::new(
        Fields::from(vec![Field::new("a", DataType::Int32, false)]),
        vec![values],
        Some(NullBuffer::new(validity)),
    )
    .slice(1, 2);

    let (out_fields, out_columns) = window.flatten();

    assert!(out_fields[0].is_nullable());

    let column = &out_columns[0];
    assert_eq!(column.len(), 2);
    assert!(column.is_null(0));
    assert!(column.is_valid(1));
}
1102+
#[test]
fn test_flatten_multiple_children() {
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None]));
    let strs: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), None, Some("c")]));
    let schema = Fields::from(vec![
        Field::new("ints", DataType::Int32, true),
        Field::new("strs", DataType::Utf8, true),
    ]);
    let validity = NullBuffer::new(BooleanBuffer::from(vec![true, false, true]));
    let struct_array = StructArray::new(schema, vec![ints, strs], Some(validity));

    let (out_fields, out_columns) = struct_array.flatten();

    assert_eq!(out_fields.len(), 2);

    // ints: [valid, null (from struct), null (from child)] => two nulls
    let int_column = &out_columns[0];
    assert_eq!(int_column.null_count(), 2);
    assert!(int_column.is_valid(0));
    assert!(int_column.is_null(1));
    assert!(int_column.is_null(2));

    // strs: [valid, null (struct and child), valid] => one null
    let str_column = &out_columns[1];
    assert_eq!(str_column.null_count(), 1);
    assert!(str_column.is_valid(0));
    assert!(str_column.is_null(1));
    assert!(str_column.is_valid(2));
}
1131+
#[test]
fn test_flatten_empty_struct() {
    // A field-less struct of five all-null rows flattens to nothing.
    let all_null = NullBuffer::new_null(5);
    let struct_array = StructArray::new_empty_fields(5, Some(all_null));

    let (out_fields, out_columns) = struct_array.flatten();

    assert!(out_fields.is_empty());
    assert!(out_columns.is_empty());
}
1141+
#[test]
fn test_flatten_field_nullability_update() {
    let required_values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let optional_values: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
    let schema = Fields::from(vec![
        Field::new("non_null", DataType::Int32, false),
        Field::new("nullable", DataType::Int32, true),
    ]);
    let validity = NullBuffer::new(BooleanBuffer::from(vec![true, true, false]));
    let struct_array = StructArray::new(
        schema,
        vec![required_values, optional_values],
        Some(validity),
    );

    let (out_fields, _) = struct_array.flatten();

    // The originally non-nullable field is promoted to nullable.
    assert!(out_fields[0].is_nullable());
    // The originally nullable field keeps its nullability.
    assert!(out_fields[1].is_nullable());
}
9611161}
0 commit comments