Skip to content

Commit c64321b

Browse files
Fix panic when decoding a union column that is followed by other columns
1 parent 20cd096 commit c64321b

File tree

1 file changed

+158
-3
lines changed

1 file changed

+158
-3
lines changed

arrow-row/src/lib.rs

Lines changed: 158 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1903,14 +1903,13 @@ unsafe fn decode_column(
19031903

19041904
let child_row = &row[1..];
19051905
rows_by_field[field_idx].push((idx, child_row));
1906-
1907-
*row = &row[row.len()..];
19081906
}
19091907

19101908
let mut child_arrays: Vec<ArrayRef> = Vec::with_capacity(converters.len());
1911-
19121909
let mut offsets = (*mode == UnionMode::Dense).then(|| Vec::with_capacity(len));
19131910

1911+
let mut bytes_consumed = vec![0usize; len];
1912+
19141913
for (field_idx, converter) in converters.iter().enumerate() {
19151914
let field_rows = &rows_by_field[field_idx];
19161915

@@ -1930,6 +1929,12 @@ unsafe fn decode_column(
19301929
let child_array =
19311930
unsafe { converter.convert_raw(&mut child_data, validate_utf8) }?;
19321931

1932+
// track bytes consumed per row: the 1 leading byte skipped by `&row[1..]` plus however much of the child row the converter consumed
1933+
for (i, (row_idx, child_row)) in field_rows.iter().enumerate() {
1934+
let remaining_len = child_data[i].len();
1935+
bytes_consumed[*row_idx] = 1 + child_row.len() - remaining_len;
1936+
}
1937+
19331938
child_arrays.push(child_array.into_iter().next().unwrap());
19341939
}
19351940
UnionMode::Sparse => {
@@ -1951,11 +1956,26 @@ unsafe fn decode_column(
19511956

19521957
let child_array =
19531958
unsafe { converter.convert_raw(&mut sparse_data, validate_utf8) }?;
1959+
1960+
// track bytes consumed only for the rows that belong to this field; other rows are recorded during their own field's pass
1961+
for (row_idx, child_row) in field_rows.iter() {
1962+
let remaining_len = sparse_data[*row_idx].len();
1963+
bytes_consumed[*row_idx] = 1 + child_row.len() - remaining_len;
1964+
}
1965+
19541966
child_arrays.push(child_array.into_iter().next().unwrap());
19551967
}
19561968
}
19571969
}
19581970

1971+
// advance all row slices by the bytes consumed
1972+
// this is necessary when multiple columns exist, since each decoder needs to
1973+
// advance the row slice by the bytes it consumed so the next column's decoder
1974+
// can read its data
1975+
for (i, row) in rows.iter_mut().enumerate() {
1976+
*row = &row[bytes_consumed[i]..];
1977+
}
1978+
19591979
// build offsets for dense unions
19601980
if let Some(ref mut offsets_vec) = offsets {
19611981
let mut count = vec![0i32; converters.len()];
@@ -4050,4 +4070,139 @@ mod tests {
40504070
// "a" < "z"
40514071
assert!(rows.row(3) < rows.row(1));
40524072
}
4073+
4074+
#[test]
4075+
fn test_row_converter_roundtrip_with_many_union_columns() {
4076+
// col 1: Union(Int32, Utf8) [67, "hello"]
4077+
let fields1 = UnionFields::new(
4078+
vec![0, 1],
4079+
vec![
4080+
Field::new("int", DataType::Int32, true),
4081+
Field::new("string", DataType::Utf8, true),
4082+
],
4083+
);
4084+
4085+
let int_array1 = Int32Array::from(vec![Some(67), None]);
4086+
let string_array1 = StringArray::from(vec![None::<&str>, Some("hello")]);
4087+
let type_ids1 = vec![0i8, 1].into();
4088+
4089+
let union_array1 = UnionArray::try_new(
4090+
fields1.clone(),
4091+
type_ids1,
4092+
None,
4093+
vec![
4094+
Arc::new(int_array1) as ArrayRef,
4095+
Arc::new(string_array1) as ArrayRef,
4096+
],
4097+
)
4098+
.unwrap();
4099+
4100+
// col 2: Union(Int32, Utf8) [100, "world"]
4101+
let fields2 = UnionFields::new(
4102+
vec![0, 1],
4103+
vec![
4104+
Field::new("int", DataType::Int32, true),
4105+
Field::new("string", DataType::Utf8, true),
4106+
],
4107+
);
4108+
4109+
let int_array2 = Int32Array::from(vec![Some(100), None]);
4110+
let string_array2 = StringArray::from(vec![None::<&str>, Some("world")]);
4111+
let type_ids2 = vec![0i8, 1].into();
4112+
4113+
let union_array2 = UnionArray::try_new(
4114+
fields2.clone(),
4115+
type_ids2,
4116+
None,
4117+
vec![
4118+
Arc::new(int_array2) as ArrayRef,
4119+
Arc::new(string_array2) as ArrayRef,
4120+
],
4121+
)
4122+
.unwrap();
4123+
4124+
// create a row converter with 2 union columns
4125+
let field1 = Field::new("col1", DataType::Union(fields1, UnionMode::Sparse), true);
4126+
let field2 = Field::new("col2", DataType::Union(fields2, UnionMode::Sparse), true);
4127+
4128+
let sort_field1 = SortField::new(field1.data_type().clone());
4129+
let sort_field2 = SortField::new(field2.data_type().clone());
4130+
4131+
let converter = RowConverter::new(vec![sort_field1, sort_field2]).unwrap();
4132+
4133+
let rows = converter
4134+
.convert_columns(&[
4135+
Arc::new(union_array1.clone()) as ArrayRef,
4136+
Arc::new(union_array2.clone()) as ArrayRef,
4137+
])
4138+
.unwrap();
4139+
4140+
// roundtrip: decode the rows back into columns
4141+
let out = converter.convert_rows(&rows).unwrap();
4142+
4143+
let [col1, col2] = out.as_slice() else {
4144+
panic!("expected 2 columns")
4145+
};
4146+
4147+
let col1 = col1.as_any().downcast_ref::<UnionArray>().unwrap();
4148+
let col2 = col2.as_any().downcast_ref::<UnionArray>().unwrap();
4149+
4150+
for (expected, got) in [union_array1, union_array2].iter().zip([col1, col2]) {
4151+
assert_eq!(expected.len(), got.len());
4152+
assert_eq!(expected.type_ids(), got.type_ids());
4153+
4154+
for i in 0..expected.len() {
4155+
assert_eq!(expected.value(i).as_ref(), got.value(i).as_ref());
4156+
}
4157+
}
4158+
}
4159+
4160+
#[test]
4161+
fn test_row_converter_roundtrip_with_one_union_column() {
4162+
let fields = UnionFields::new(
4163+
vec![0, 1],
4164+
vec![
4165+
Field::new("int", DataType::Int32, true),
4166+
Field::new("string", DataType::Utf8, true),
4167+
],
4168+
);
4169+
4170+
let int_array = Int32Array::from(vec![Some(67), None]);
4171+
let string_array = StringArray::from(vec![None::<&str>, Some("hello")]);
4172+
let type_ids = vec![0i8, 1].into();
4173+
4174+
let union_array = UnionArray::try_new(
4175+
fields.clone(),
4176+
type_ids,
4177+
None,
4178+
vec![
4179+
Arc::new(int_array) as ArrayRef,
4180+
Arc::new(string_array) as ArrayRef,
4181+
],
4182+
)
4183+
.unwrap();
4184+
4185+
let field = Field::new("col", DataType::Union(fields, UnionMode::Sparse), true);
4186+
let sort_field = SortField::new(field.data_type().clone());
4187+
let converter = RowConverter::new(vec![sort_field]).unwrap();
4188+
4189+
let rows = converter
4190+
.convert_columns(&[Arc::new(union_array.clone()) as ArrayRef])
4191+
.unwrap();
4192+
4193+
// roundtrip: decode the rows back into a column
4194+
let out = converter.convert_rows(&rows).unwrap();
4195+
4196+
let [col1] = out.as_slice() else {
4197+
panic!("expected 1 column")
4198+
};
4199+
4200+
let col = col1.as_any().downcast_ref::<UnionArray>().unwrap();
4201+
assert_eq!(col.len(), union_array.len());
4202+
assert_eq!(col.type_ids(), union_array.type_ids());
4203+
4204+
for i in 0..col.len() {
4205+
assert_eq!(col.value(i).as_ref(), union_array.value(i).as_ref());
4206+
}
4207+
}
40534208
}

0 commit comments

Comments
 (0)