Skip to content

Commit ba7698c

Browse files
Fix panic
1 parent 20cd096 commit ba7698c

File tree

1 file changed

+161
-3
lines changed

1 file changed

+161
-3
lines changed

arrow-row/src/lib.rs

Lines changed: 161 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1903,14 +1903,13 @@ unsafe fn decode_column(
19031903

19041904
let child_row = &row[1..];
19051905
rows_by_field[field_idx].push((idx, child_row));
1906-
1907-
*row = &row[row.len()..];
19081906
}
19091907

19101908
let mut child_arrays: Vec<ArrayRef> = Vec::with_capacity(converters.len());
1911-
19121909
let mut offsets = (*mode == UnionMode::Dense).then(|| Vec::with_capacity(len));
19131910

1911+
let mut bytes_consumed = vec![0usize; len];
1912+
19141913
for (field_idx, converter) in converters.iter().enumerate() {
19151914
let field_rows = &rows_by_field[field_idx];
19161915

@@ -1930,6 +1929,12 @@ unsafe fn decode_column(
19301929
let child_array =
19311930
unsafe { converter.convert_raw(&mut child_data, validate_utf8) }?;
19321931

1932+
// track bytes consumed by comparing original and remaining lengths
1933+
for (i, (row_idx, child_row)) in field_rows.iter().enumerate() {
1934+
let remaining_len = child_data[i].len();
1935+
bytes_consumed[*row_idx] = 1 + child_row.len() - remaining_len;
1936+
}
1937+
19331938
child_arrays.push(child_array.into_iter().next().unwrap());
19341939
}
19351940
UnionMode::Sparse => {
@@ -1951,11 +1956,26 @@ unsafe fn decode_column(
19511956

19521957
let child_array =
19531958
unsafe { converter.convert_raw(&mut sparse_data, validate_utf8) }?;
1959+
1960+
// track bytes consumed for rows that belong to this field
1961+
for (row_idx, child_row) in field_rows.iter() {
1962+
let remaining_len = sparse_data[*row_idx].len();
1963+
bytes_consumed[*row_idx] = 1 + child_row.len() - remaining_len;
1964+
}
1965+
19541966
child_arrays.push(child_array.into_iter().next().unwrap());
19551967
}
19561968
}
19571969
}
19581970

1971+
// advance all row slices by the bytes consumed
1972+
// this is necessary when multiple columns exist, since each decoder needs to
1973+
// advance the row slice by the bytes it consumed so the next column's decoder
1974+
// can read its data
1975+
for (i, row) in rows.iter_mut().enumerate() {
1976+
*row = &row[bytes_consumed[i]..];
1977+
}
1978+
19591979
// build offsets for dense unions
19601980
if let Some(ref mut offsets_vec) = offsets {
19611981
let mut count = vec![0i32; converters.len()];
@@ -4050,4 +4070,142 @@ mod tests {
40504070
// "a" < "z"
40514071
assert!(rows.row(3) < rows.row(1));
40524072
}
4073+
4074+
/// Round-trips two union columns through a single `RowConverter` and checks
/// that both columns come back unchanged.
///
/// Each column is a `Union(Int32, Utf8)` with two rows: row 0 carries the
/// "int" child's value and row 1 carries the "string" child's value. Both
/// columns are encoded together with `convert_columns`, decoded again with
/// `convert_rows`, and compared with the inputs by length, type ids and
/// per-row values.
///
/// NOTE(review): presumably the regression test for the multi-column case,
/// where the second column can only be decoded if the first column's decoder
/// leaves each row slice positioned at the second column's bytes. Confirm
/// against `decode_column`.
#[test]
4075+
fn test_row_converter_roundtrip_with_many_union_columns() {
4076+
// col 1: Union(Int32, Utf8) [67, "hello"]
4077+
let fields1 = UnionFields::try_new(
4078+
vec![0, 1],
4079+
vec![
4080+
Field::new("int", DataType::Int32, true),
4081+
Field::new("string", DataType::Utf8, true),
4082+
],
4083+
)
4084+
.unwrap();
4085+
4086+
let int_array1 = Int32Array::from(vec![Some(67), None]);
4087+
let string_array1 = StringArray::from(vec![None::<&str>, Some("hello")]);
4088+
// one type id per row: 0 -> "int" child, 1 -> "string" child
let type_ids1 = vec![0i8, 1].into();
4089+
4090+
let union_array1 = UnionArray::try_new(
4091+
// cloned because `fields1` is moved into `DataType::Union` further down
fields1.clone(),
4092+
type_ids1,
4093+
// offsets argument left as None; the column type below is UnionMode::Sparse
None,
4094+
vec![
4095+
Arc::new(int_array1) as ArrayRef,
4096+
Arc::new(string_array1) as ArrayRef,
4097+
],
4098+
)
4099+
.unwrap();
4100+
4101+
// col 2: Union(Int32, Utf8) [100, "world"]
4102+
let fields2 = UnionFields::try_new(
4103+
vec![0, 1],
4104+
vec![
4105+
Field::new("int", DataType::Int32, true),
4106+
Field::new("string", DataType::Utf8, true),
4107+
],
4108+
)
4109+
.unwrap();
4110+
4111+
let int_array2 = Int32Array::from(vec![Some(100), None]);
4112+
let string_array2 = StringArray::from(vec![None::<&str>, Some("world")]);
4113+
// same row-to-child mapping as the first column
let type_ids2 = vec![0i8, 1].into();
4114+
4115+
let union_array2 = UnionArray::try_new(
4116+
fields2.clone(),
4117+
type_ids2,
4118+
None,
4119+
vec![
4120+
Arc::new(int_array2) as ArrayRef,
4121+
Arc::new(string_array2) as ArrayRef,
4122+
],
4123+
)
4124+
.unwrap();
4125+
4126+
// create a row converter with 2 union columns
4127+
let field1 = Field::new("col1", DataType::Union(fields1, UnionMode::Sparse), true);
4128+
let field2 = Field::new("col2", DataType::Union(fields2, UnionMode::Sparse), true);
4129+
4130+
let sort_field1 = SortField::new(field1.data_type().clone());
4131+
let sort_field2 = SortField::new(field2.data_type().clone());
4132+
4133+
let converter = RowConverter::new(vec![sort_field1, sort_field2]).unwrap();
4134+
4135+
// the arrays are cloned so the originals stay available for the comparison below
let rows = converter
4136+
.convert_columns(&[
4137+
Arc::new(union_array1.clone()) as ArrayRef,
4138+
Arc::new(union_array2.clone()) as ArrayRef,
4139+
])
4140+
.unwrap();
4141+
4142+
// roundtrip
4143+
let out = converter.convert_rows(&rows).unwrap();
4144+
4145+
// exactly one decoded array is expected per input column
let [col1, col2] = out.as_slice() else {
4146+
panic!("expected 2 columns")
4147+
};
4148+
4149+
let col1 = col1.as_any().downcast_ref::<UnionArray>().unwrap();
4150+
let col2 = col2.as_any().downcast_ref::<UnionArray>().unwrap();
4151+
4152+
// compare each decoded column with the array it was encoded from
for (expected, got) in [union_array1, union_array2].iter().zip([col1, col2]) {
4153+
assert_eq!(expected.len(), got.len());
4154+
assert_eq!(expected.type_ids(), got.type_ids());
4155+
4156+
for i in 0..expected.len() {
4157+
assert_eq!(expected.value(i).as_ref(), got.value(i).as_ref());
4158+
}
4159+
}
4160+
}
4161+
4162+
/// Round-trips a single union column through a `RowConverter` and checks that
/// it comes back unchanged.
///
/// The column is a `Union(Int32, Utf8)` with two rows: row 0 carries the "int"
/// child's value (67) and row 1 carries the "string" child's value ("hello").
/// It is encoded with `convert_columns`, decoded with `convert_rows`, and
/// compared with the input by length, type ids and per-row values.
#[test]
4163+
fn test_row_converter_roundtrip_with_one_union_column() {
4164+
let fields = UnionFields::try_new(
4165+
vec![0, 1],
4166+
vec![
4167+
Field::new("int", DataType::Int32, true),
4168+
Field::new("string", DataType::Utf8, true),
4169+
],
4170+
)
4171+
.unwrap();
4172+
4173+
let int_array = Int32Array::from(vec![Some(67), None]);
4174+
let string_array = StringArray::from(vec![None::<&str>, Some("hello")]);
4175+
// one type id per row: 0 -> "int" child, 1 -> "string" child
let type_ids = vec![0i8, 1].into();
4176+
4177+
let union_array = UnionArray::try_new(
4178+
// cloned because `fields` is moved into `DataType::Union` further down
fields.clone(),
4179+
type_ids,
4180+
// offsets argument left as None; the column type below is UnionMode::Sparse
None,
4181+
vec![
4182+
Arc::new(int_array) as ArrayRef,
4183+
Arc::new(string_array) as ArrayRef,
4184+
],
4185+
)
4186+
.unwrap();
4187+
4188+
let field = Field::new("col", DataType::Union(fields, UnionMode::Sparse), true);
4189+
let sort_field = SortField::new(field.data_type().clone());
4190+
let converter = RowConverter::new(vec![sort_field]).unwrap();
4191+
4192+
// the array is cloned so the original stays available for the comparison below
let rows = converter
4193+
.convert_columns(&[Arc::new(union_array.clone()) as ArrayRef])
4194+
.unwrap();
4195+
4196+
// roundtrip
4197+
let out = converter.convert_rows(&rows).unwrap();
4198+
4199+
// exactly one decoded array is expected for the single input column
let [col1] = out.as_slice() else {
4200+
panic!("expected 1 column")
4201+
};
4202+
4203+
let col = col1.as_any().downcast_ref::<UnionArray>().unwrap();
4204+
assert_eq!(col.len(), union_array.len());
4205+
assert_eq!(col.type_ids(), union_array.type_ids());
4206+
4207+
// compare the decoded value of every row with the original
for i in 0..col.len() {
4208+
assert_eq!(col.value(i).as_ref(), union_array.value(i).as_ref());
4209+
}
4210+
}
40534211
}

0 commit comments

Comments
 (0)