Skip to content

Commit c184cff

Browse files
author
Han You
committed
Fix RecordBatch::normalize() null bitmap bug and add StructArray::flatten()
Currently, RecordBatch::normalize() has a bug: the top-level struct's null bitmap is not propagated into the null bitmaps of the resulting normalized arrays. In other words, a child element may suddenly appear non-null, losing the fact that the parent struct is null at that index. See the test in this change for a reproduction of the bug. This change fixes that behavior. It also adds StructArray::flatten(), which mirrors arrow-cpp's semantics and handles the aforementioned case correctly. The fixed RecordBatch::normalize() now uses StructArray::flatten() under the hood.
1 parent 38d78c3 commit c184cff

2 files changed

Lines changed: 240 additions & 19 deletions

File tree

arrow-array/src/array/struct_array.rs

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,70 @@ impl StructArray {
343343
fields,
344344
}
345345
}
346+
347+
/// Returns the children of this [`StructArray`] with the struct's validity
348+
/// bitmap AND'd into each child's validity bitmap.
349+
///
350+
/// This ensures that positions where the struct itself is null are also
351+
/// null in each returned child array. Fields that were non-nullable are
352+
/// marked nullable in the returned [`Fields`] when the struct has nulls.
353+
///
354+
/// If the struct has no nulls, children and fields are returned as-is.
355+
///
356+
/// This mirrors the semantics of C++ Arrow's `StructArray::Flatten`.
357+
///
358+
/// # Example
359+
///
360+
/// ```
361+
/// # use std::sync::Arc;
362+
/// # use arrow_array::{Array, ArrayRef, Int32Array, StructArray};
363+
/// # use arrow_buffer::{BooleanBuffer, NullBuffer};
364+
/// # use arrow_schema::{DataType, Field, Fields};
365+
/// let child = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
366+
/// let struct_nulls = NullBuffer::new(BooleanBuffer::from(vec![true, false, true]));
367+
/// let sa = StructArray::new(
368+
/// Fields::from(vec![Field::new("a", DataType::Int32, false)]),
369+
/// vec![child],
370+
/// Some(struct_nulls),
371+
/// );
372+
/// let (fields, columns) = sa.flatten();
373+
/// assert!(fields[0].is_nullable());
374+
/// assert!(columns[0].is_null(1));
375+
/// ```
376+
pub fn flatten(&self) -> (Fields, Vec<ArrayRef>) {
377+
let schema_fields = self.fields();
378+
379+
let struct_nulls = match &self.nulls {
380+
Some(n) => n,
381+
None => return (schema_fields.clone(), self.fields.clone()),
382+
};
383+
384+
let new_fields: Fields = schema_fields
385+
.iter()
386+
.map(|f| {
387+
if f.is_nullable() {
388+
Arc::clone(f)
389+
} else {
390+
Arc::new(f.as_ref().clone().with_nullable(true))
391+
}
392+
})
393+
.collect::<Vec<_>>()
394+
.into();
395+
396+
let new_columns = self
397+
.fields
398+
.iter()
399+
.map(|child| {
400+
let merged = NullBuffer::union(Some(struct_nulls), child.nulls());
401+
// SAFETY: We only make the null buffer more restrictive (adding nulls).
402+
// All data buffers and child data remain unchanged.
403+
let data = child.to_data().into_builder().nulls(merged);
404+
make_array(unsafe { data.build_unchecked() })
405+
})
406+
.collect();
407+
408+
(new_fields, new_columns)
409+
}
346410
}
347411

348412
impl From<ArrayData> for StructArray {
@@ -958,4 +1022,140 @@ mod tests {
9581022

9591023
StructArray::try_new(fields, arrays, nulls).expect("should not error");
9601024
}
1025+
1026+
#[test]
fn test_flatten_no_nulls() {
    // Struct built without a validity bitmap: the flattened child should stay
    // non-nullable and contain no nulls.
    let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let field = Arc::new(Field::new("a", DataType::Int32, false));
    let input = StructArray::from(vec![(field, values)]);

    let (out_fields, out_columns) = input.flatten();

    assert_eq!(out_columns.len(), 1);
    let column = &out_columns[0];
    assert_eq!(column.len(), 3);
    assert_eq!(column.null_count(), 0);
    assert!(!out_fields[0].is_nullable());
}
1041+
1042+
#[test]
fn test_flatten_struct_nulls_child_no_nulls() {
    // Only the struct carries nulls; the child must inherit them.
    let validity = NullBuffer::new(BooleanBuffer::from(vec![true, false, true]));
    let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let schema = Fields::from(vec![Field::new("a", DataType::Int32, false)]);
    let input = StructArray::new(schema, vec![values], Some(validity));

    let (out_fields, out_columns) = input.flatten();
    let column = &out_columns[0];

    assert!(out_fields[0].is_nullable());
    for (row, expect_valid) in [true, false, true].into_iter().enumerate() {
        assert_eq!(column.is_valid(row), expect_valid, "row {row}");
    }
    assert_eq!(column.null_count(), 1);
}
1060+
1061+
#[test]
fn test_flatten_both_have_nulls() {
    // Row-by-row validity:
    //   struct: [valid, null,  valid, valid]
    //   child:  [valid, valid, null,  valid]
    //   merged: [valid, null,  null,  valid]
    let validity = NullBuffer::new(BooleanBuffer::from(vec![true, false, true, true]));
    let values: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)]));
    let schema = Fields::from(vec![Field::new("a", DataType::Int32, true)]);
    let input = StructArray::new(schema, vec![values], Some(validity));

    let (out_fields, out_columns) = input.flatten();
    let column = &out_columns[0];

    assert!(out_fields[0].is_nullable());
    for (row, expect_valid) in [true, false, false, true].into_iter().enumerate() {
        assert_eq!(column.is_valid(row), expect_valid, "row {row}");
    }
    assert_eq!(column.null_count(), 2);
}
1083+
1084+
#[test]
fn test_flatten_sliced_struct() {
    // Flattening a slice must use the validity of the sliced window
    // (rows 1..3 of the original), not of the full array.
    let validity = NullBuffer::new(BooleanBuffer::from(vec![true, false, true, false]));
    let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
    let schema = Fields::from(vec![Field::new("a", DataType::Int32, false)]);
    let window = StructArray::new(schema, vec![values], Some(validity)).slice(1, 2);

    let (out_fields, out_columns) = window.flatten();
    let column = &out_columns[0];

    assert!(out_fields[0].is_nullable());
    assert_eq!(column.len(), 2);
    for (row, expect_valid) in [false, true].into_iter().enumerate() {
        assert_eq!(column.is_valid(row), expect_valid, "row {row}");
    }
}
1102+
1103+
#[test]
fn test_flatten_multiple_children() {
    let validity = NullBuffer::new(BooleanBuffer::from(vec![true, false, true]));
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None]));
    let strs: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), None, Some("c")]));
    let schema = Fields::from(vec![
        Field::new("ints", DataType::Int32, true),
        Field::new("strs", DataType::Utf8, true),
    ]);
    let input = StructArray::new(schema, vec![ints, strs], Some(validity));

    let (out_fields, out_columns) = input.flatten();
    assert_eq!(out_fields.len(), 2);

    // Expected validity and null count per flattened column:
    //   ints: [valid, null (struct), null (child)]   -> 2 nulls
    //   strs: [valid, null (struct + child), valid]  -> 1 null
    let expectations = [([true, false, false], 2), ([true, false, true], 1)];
    for (col_idx, (validity_rows, nulls)) in expectations.into_iter().enumerate() {
        let column = &out_columns[col_idx];
        assert_eq!(column.null_count(), nulls, "column {col_idx}");
        for (row, expect_valid) in validity_rows.into_iter().enumerate() {
            assert_eq!(
                column.is_valid(row),
                expect_valid,
                "column {col_idx}, row {row}"
            );
        }
    }
}
1131+
1132+
#[test]
fn test_flatten_empty_struct() {
    // A struct with no children yields no fields and no columns, even when
    // every row of the struct is null.
    let all_null = NullBuffer::new_null(5);
    let input = StructArray::new_empty_fields(5, Some(all_null));

    let (out_fields, out_columns) = input.flatten();

    assert!(out_fields.is_empty());
    assert!(out_columns.is_empty());
}
1141+
1142+
#[test]
fn test_flatten_field_nullability_update() {
    let validity = NullBuffer::new(BooleanBuffer::from(vec![true, true, false]));
    let required: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let optional: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
    let schema = Fields::from(vec![
        Field::new("non_null", DataType::Int32, false),
        Field::new("nullable", DataType::Int32, true),
    ]);
    let input = StructArray::new(schema, vec![required, optional], Some(validity));

    let (out_fields, _) = input.flatten();

    // "non_null" flips from non-nullable to nullable; "nullable" stays nullable.
    for idx in 0..2 {
        assert!(out_fields[idx].is_nullable(), "field {idx}");
    }
}
9611161
}

arrow-array/src/record_batch.rs

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -543,37 +543,29 @@ impl RecordBatch {
543543
0 => usize::MAX,
544544
val => val,
545545
};
546-
let mut stack: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self
546+
let mut stack: Vec<(usize, ArrayRef, String, FieldRef)> = self
547547
.columns
548548
.iter()
549549
.zip(self.schema.fields())
550550
.rev()
551-
.map(|(c, f)| {
552-
let name_vec: Vec<&str> = vec![f.name()];
553-
(0, c, name_vec, f)
554-
})
551+
.map(|(c, f)| (0, c.clone(), f.name().clone(), Arc::clone(f)))
555552
.collect();
556553
let mut columns: Vec<ArrayRef> = Vec::new();
557554
let mut fields: Vec<FieldRef> = Vec::new();
558555

559556
while let Some((depth, c, name, field_ref)) = stack.pop() {
560557
match field_ref.data_type() {
561-
DataType::Struct(ff) if depth < max_level => {
562-
// Need to zip these in reverse to maintain original order
563-
for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() {
564-
let mut name = name.clone();
565-
name.push(separator);
566-
name.push(fff.name());
567-
stack.push((depth + 1, cff, name, fff))
558+
DataType::Struct(_) if depth < max_level => {
559+
let (flat_fields, flat_cols) = c.as_struct().flatten();
560+
for (cff, fff) in flat_cols.into_iter().zip(flat_fields.iter()).rev() {
561+
let child_name = format!("{name}{separator}{}", fff.name());
562+
stack.push((depth + 1, cff, child_name, Arc::clone(fff)))
568563
}
569564
}
570565
_ => {
571-
let updated_field = Field::new(
572-
name.concat(),
573-
field_ref.data_type().clone(),
574-
field_ref.is_nullable(),
575-
);
576-
columns.push(c.clone());
566+
let updated_field =
567+
Field::new(name, field_ref.data_type().clone(), field_ref.is_nullable());
568+
columns.push(c);
577569
fields.push(Arc::new(updated_field));
578570
}
579571
}
@@ -973,7 +965,7 @@ mod tests {
973965
use crate::{
974966
BooleanArray, Int8Array, Int32Array, Int64Array, ListArray, StringArray, StringViewArray,
975967
};
976-
use arrow_buffer::{Buffer, ToByteSlice};
968+
use arrow_buffer::{Buffer, NullBuffer, ToByteSlice};
977969
use arrow_data::{ArrayData, ArrayDataBuilder};
978970
use arrow_schema::Fields;
979971
use std::collections::HashMap;
@@ -1771,4 +1763,33 @@ mod tests {
17711763
"bar"
17721764
);
17731765
}
1766+
1767+
#[test]
fn test_normalize_nullable_struct() {
    // Reproduction: a nullable struct column whose only child is declared
    // non-nullable. After normalization the child column must be null
    // wherever the struct was null.
    let inner_fields = || Fields::from(vec![Field::new("x", DataType::Int32, false)]);

    let validity = NullBuffer::new(arrow_buffer::BooleanBuffer::from(vec![true, false, true]));
    let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let struct_column: ArrayRef =
        Arc::new(StructArray::new(inner_fields(), vec![values], Some(validity)));

    let outer_field = Field::new("s", DataType::Struct(inner_fields()), true);
    let schema = Arc::new(Schema::new(vec![outer_field]));
    let batch = RecordBatch::try_new(schema, vec![struct_column]).unwrap();

    let normalized = batch.normalize(".", None).unwrap();

    assert_eq!(normalized.num_columns(), 1);
    let out_schema = normalized.schema();
    let out_field = out_schema.field(0);
    assert_eq!(out_field.name(), "s.x");
    assert!(out_field.is_nullable());

    let column = normalized.column(0);
    for (row, expect_valid) in [true, false, true].into_iter().enumerate() {
        assert_eq!(column.is_valid(row), expect_valid, "row {row}");
    }
}
17741795
}

0 commit comments

Comments
 (0)