Skip to content

Commit a92c8a9

Browse files
authored
[Variant] Support Shredded Lists/Array in variant_get (#9049)
# Which issue does this PR close?

<!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. -->

- Closes #8082.

# Rationale for this change

<!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. -->

# What changes are included in this PR?

- Move `ArrayVariantToArrowRowBuilder` from `shred_variant` to `variant_to_arrow` so it can be shared with `variant_get`.
- Reorder definitions in `variant_to_arrow` to clarify the hierarchy: start with the top-level `VariantToArrowRowBuilder`, then second-level builders such as `PrimitiveVariantToArrowRowBuilder` and `ArrayVariantToArrowRowBuilder`, etc.
- Add tests for `variant_get` with lists.

# Are these changes tested?

Yes

# Are there any user-facing changes?

<!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. If there are any breaking changes to public APIs, please call them out. -->

`variant_get` now supports lists.
1 parent 618cab7 commit a92c8a9

File tree

3 files changed

+226
-27
lines changed

3 files changed

+226
-27
lines changed

parquet-variant-compute/src/shred_variant.rs

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -274,7 +274,7 @@ impl<'a> VariantToShreddedArrayVariantRowBuilder<'a> {
274274

275275
fn append_null(&mut self) -> Result<()> {
276276
self.value_builder.append_value(Variant::Null);
277-
self.typed_value_builder.append_null();
277+
self.typed_value_builder.append_null()?;
278278
Ok(())
279279
}
280280

@@ -284,12 +284,13 @@ impl<'a> VariantToShreddedArrayVariantRowBuilder<'a> {
284284
match variant {
285285
Variant::List(list) => {
286286
self.value_builder.append_null();
287-
self.typed_value_builder.append_value(list)?;
287+
self.typed_value_builder
288+
.append_value(&Variant::List(list))?;
288289
Ok(true)
289290
}
290291
other => {
291292
self.value_builder.append_value(other);
292-
self.typed_value_builder.append_null();
293+
self.typed_value_builder.append_null()?;
293294
Ok(false)
294295
}
295296
}

parquet-variant-compute/src/variant_get.rs

Lines changed: 183 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -339,10 +339,11 @@ mod test {
339339
Array, ArrayRef, AsArray, BinaryArray, BinaryViewArray, BooleanArray, Date32Array,
340340
Date64Array, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array,
341341
Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array,
342-
LargeBinaryArray, LargeStringArray, NullBuilder, StringArray, StringViewArray, StructArray,
342+
LargeBinaryArray, LargeListArray, LargeListViewArray, LargeStringArray, ListArray,
343+
ListViewArray, NullBuilder, StringArray, StringViewArray, StructArray,
343344
Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
344345
};
345-
use arrow::buffer::NullBuffer;
346+
use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
346347
use arrow::compute::CastOptions;
347348
use arrow::datatypes::DataType::{Int16, Int32, Int64};
348349
use arrow::datatypes::i256;
@@ -351,8 +352,8 @@ mod test {
351352
use arrow_schema::{DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit};
352353
use chrono::DateTime;
353354
use parquet_variant::{
354-
EMPTY_VARIANT_METADATA_BYTES, Variant, VariantDecimal4, VariantDecimal8, VariantDecimal16,
355-
VariantDecimalType, VariantPath,
355+
EMPTY_VARIANT_METADATA_BYTES, Variant, VariantBuilder, VariantDecimal4, VariantDecimal8,
356+
VariantDecimal16, VariantDecimalType, VariantPath,
356357
};
357358

358359
fn single_variant_get_test(input_json: &str, path: VariantPath, expected_json: &str) {
@@ -4158,4 +4159,182 @@ mod test {
41584159
assert!(inner_values_result.is_null(1));
41594160
assert_eq!(inner_values_result.value(2), 333);
41604161
}
4162+
4163+
#[test]
4164+
fn test_variant_get_list_like_safe_cast() {
4165+
let string_array: ArrayRef = Arc::new(StringArray::from(vec![
4166+
r#"[1, "two", 3]"#,
4167+
"\"not a list\"",
4168+
]));
4169+
let variant_array = ArrayRef::from(json_to_variant(&string_array).unwrap());
4170+
4171+
let value_array: ArrayRef = {
4172+
let mut builder = VariantBuilder::new();
4173+
builder.append_value("two");
4174+
let (_, value_bytes) = builder.finish();
4175+
Arc::new(BinaryViewArray::from(vec![
4176+
None,
4177+
Some(value_bytes.as_slice()),
4178+
None,
4179+
]))
4180+
};
4181+
let typed_value_array: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), None, Some(3)]));
4182+
let struct_fields = Fields::from(vec![
4183+
Field::new("value", DataType::BinaryView, true),
4184+
Field::new("typed_value", DataType::Int64, true),
4185+
]);
4186+
let struct_array: ArrayRef = Arc::new(
4187+
StructArray::try_new(
4188+
struct_fields.clone(),
4189+
vec![value_array.clone(), typed_value_array.clone()],
4190+
None,
4191+
)
4192+
.unwrap(),
4193+
);
4194+
4195+
let request_field = Arc::new(Field::new("item", DataType::Int64, true));
4196+
let result_field = Arc::new(Field::new("item", DataType::Struct(struct_fields), true));
4197+
4198+
let expectations = vec![
4199+
(
4200+
DataType::List(request_field.clone()),
4201+
Arc::new(ListArray::new(
4202+
result_field.clone(),
4203+
OffsetBuffer::new(ScalarBuffer::from(vec![0, 3, 3])),
4204+
struct_array.clone(),
4205+
Some(NullBuffer::from(vec![true, false])),
4206+
)) as ArrayRef,
4207+
),
4208+
(
4209+
DataType::LargeList(request_field.clone()),
4210+
Arc::new(LargeListArray::new(
4211+
result_field.clone(),
4212+
OffsetBuffer::new(ScalarBuffer::from(vec![0, 3, 3])),
4213+
struct_array.clone(),
4214+
Some(NullBuffer::from(vec![true, false])),
4215+
)) as ArrayRef,
4216+
),
4217+
(
4218+
DataType::ListView(request_field.clone()),
4219+
Arc::new(ListViewArray::new(
4220+
result_field.clone(),
4221+
ScalarBuffer::from(vec![0, 3]),
4222+
ScalarBuffer::from(vec![3, 0]),
4223+
struct_array.clone(),
4224+
Some(NullBuffer::from(vec![true, false])),
4225+
)) as ArrayRef,
4226+
),
4227+
(
4228+
DataType::LargeListView(request_field),
4229+
Arc::new(LargeListViewArray::new(
4230+
result_field,
4231+
ScalarBuffer::from(vec![0, 3]),
4232+
ScalarBuffer::from(vec![3, 0]),
4233+
struct_array,
4234+
Some(NullBuffer::from(vec![true, false])),
4235+
)) as ArrayRef,
4236+
),
4237+
];
4238+
4239+
for (request_type, expected) in expectations {
4240+
let options = GetOptions::new().with_as_type(Some(FieldRef::from(Field::new(
4241+
"result",
4242+
request_type.clone(),
4243+
true,
4244+
))));
4245+
4246+
let result = variant_get(&variant_array, options).unwrap();
4247+
assert_eq!(result.data_type(), expected.data_type());
4248+
assert_eq!(&result, &expected);
4249+
}
4250+
}
4251+
4252+
#[test]
4253+
fn test_variant_get_list_like_unsafe_cast_errors_on_element_mismatch() {
4254+
let string_array: ArrayRef =
4255+
Arc::new(StringArray::from(vec![r#"[1, "two", 3]"#, "[4, 5]"]));
4256+
let variant_array = ArrayRef::from(json_to_variant(&string_array).unwrap());
4257+
let cast_options = CastOptions {
4258+
safe: false,
4259+
..Default::default()
4260+
};
4261+
4262+
let item_field = Arc::new(Field::new("item", DataType::Int64, true));
4263+
let request_types = vec![
4264+
DataType::List(item_field.clone()),
4265+
DataType::LargeList(item_field.clone()),
4266+
DataType::ListView(item_field.clone()),
4267+
DataType::LargeListView(item_field),
4268+
];
4269+
4270+
for request_type in request_types {
4271+
let options = GetOptions::new()
4272+
.with_as_type(Some(FieldRef::from(Field::new(
4273+
"result",
4274+
request_type.clone(),
4275+
true,
4276+
))))
4277+
.with_cast_options(cast_options.clone());
4278+
4279+
let err = variant_get(&variant_array, options).unwrap_err();
4280+
assert!(
4281+
err.to_string()
4282+
.contains("Failed to extract primitive of type Int64")
4283+
);
4284+
}
4285+
}
4286+
4287+
#[test]
4288+
fn test_variant_get_list_like_unsafe_cast_errors_on_non_list() {
4289+
let string_array: ArrayRef = Arc::new(StringArray::from(vec!["[1, 2]", "\"not a list\""]));
4290+
let variant_array = ArrayRef::from(json_to_variant(&string_array).unwrap());
4291+
let cast_options = CastOptions {
4292+
safe: false,
4293+
..Default::default()
4294+
};
4295+
let item_field = Arc::new(Field::new("item", Int64, true));
4296+
let data_types = vec![
4297+
DataType::List(item_field.clone()),
4298+
DataType::LargeList(item_field.clone()),
4299+
DataType::ListView(item_field.clone()),
4300+
DataType::LargeListView(item_field),
4301+
];
4302+
4303+
for data_type in data_types {
4304+
let options = GetOptions::new()
4305+
.with_as_type(Some(FieldRef::from(Field::new("result", data_type, true))))
4306+
.with_cast_options(cast_options.clone());
4307+
4308+
let err = variant_get(&variant_array, options).unwrap_err();
4309+
assert!(
4310+
err.to_string()
4311+
.contains("Failed to extract list from variant"),
4312+
);
4313+
}
4314+
}
4315+
4316+
#[test]
4317+
fn test_variant_get_fixed_size_list_not_implemented() {
4318+
let string_array: ArrayRef = Arc::new(StringArray::from(vec!["[1, 2]", "\"not a list\""]));
4319+
let variant_array = ArrayRef::from(json_to_variant(&string_array).unwrap());
4320+
let item_field = Arc::new(Field::new("item", Int64, true));
4321+
for safe in [true, false] {
4322+
let options = GetOptions::new()
4323+
.with_as_type(Some(FieldRef::from(Field::new(
4324+
"result",
4325+
DataType::FixedSizeList(item_field.clone(), 2),
4326+
true,
4327+
))))
4328+
.with_cast_options(CastOptions {
4329+
safe,
4330+
..Default::default()
4331+
});
4332+
4333+
let err = variant_get(&variant_array, options).unwrap_err();
4334+
assert!(
4335+
err.to_string()
4336+
.contains("Converting unshredded variant arrays to arrow fixed-size lists")
4337+
);
4338+
}
4339+
}
41614340
}

parquet-variant-compute/src/variant_to_arrow.rs

Lines changed: 39 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ use arrow::compute::{CastOptions, DecimalCast};
3434
use arrow::datatypes::{self, DataType, DecimalType};
3535
use arrow::error::{ArrowError, Result};
3636
use arrow_schema::{FieldRef, TimeUnit};
37-
use parquet_variant::{Variant, VariantList, VariantPath};
37+
use parquet_variant::{Variant, VariantPath};
3838
use std::sync::Arc;
3939

4040
/// Builder for converting variant values into strongly typed Arrow arrays.
@@ -43,6 +43,7 @@ use std::sync::Arc;
4343
/// with casting of leaf values to specific types.
4444
pub(crate) enum VariantToArrowRowBuilder<'a> {
4545
Primitive(PrimitiveVariantToArrowRowBuilder<'a>),
46+
Array(ArrayVariantToArrowRowBuilder<'a>),
4647
BinaryVariant(VariantToBinaryVariantArrowRowBuilder),
4748

4849
// Path extraction wrapper - contains a boxed enum for any of the above
@@ -54,6 +55,7 @@ impl<'a> VariantToArrowRowBuilder<'a> {
5455
use VariantToArrowRowBuilder::*;
5556
match self {
5657
Primitive(b) => b.append_null(),
58+
Array(b) => b.append_null(),
5759
BinaryVariant(b) => b.append_null(),
5860
WithPath(path_builder) => path_builder.append_null(),
5961
}
@@ -63,6 +65,7 @@ impl<'a> VariantToArrowRowBuilder<'a> {
6365
use VariantToArrowRowBuilder::*;
6466
match self {
6567
Primitive(b) => b.append_value(&value),
68+
Array(b) => b.append_value(&value),
6669
BinaryVariant(b) => b.append_value(value),
6770
WithPath(path_builder) => path_builder.append_value(value),
6871
}
@@ -72,6 +75,7 @@ impl<'a> VariantToArrowRowBuilder<'a> {
7275
use VariantToArrowRowBuilder::*;
7376
match self {
7477
Primitive(b) => b.finish(),
78+
Array(b) => b.finish(),
7579
BinaryVariant(b) => b.finish(),
7680
WithPath(path_builder) => path_builder.finish(),
7781
}
@@ -99,15 +103,15 @@ pub(crate) fn make_variant_to_arrow_row_builder<'a>(
99103
));
100104
}
101105
Some(
102-
DataType::List(_)
106+
data_type @ (DataType::List(_)
103107
| DataType::LargeList(_)
104108
| DataType::ListView(_)
105109
| DataType::LargeListView(_)
106-
| DataType::FixedSizeList(..),
110+
| DataType::FixedSizeList(..)),
107111
) => {
108-
return Err(ArrowError::NotYetImplemented(
109-
"Converting unshredded variant arrays to arrow lists".to_string(),
110-
));
112+
let builder =
113+
ArrayVariantToArrowRowBuilder::try_new(data_type, cast_options, capacity)?;
114+
Array(builder)
111115
}
112116
Some(data_type) => {
113117
let builder =
@@ -526,7 +530,7 @@ impl<'a> ArrayVariantToArrowRowBuilder<'a> {
526530
Ok(builder)
527531
}
528532

529-
pub(crate) fn append_null(&mut self) {
533+
pub(crate) fn append_null(&mut self) -> Result<()> {
530534
match self {
531535
Self::List(builder) => builder.append_null(),
532536
Self::LargeList(builder) => builder.append_null(),
@@ -535,12 +539,12 @@ impl<'a> ArrayVariantToArrowRowBuilder<'a> {
535539
}
536540
}
537541

538-
pub(crate) fn append_value(&mut self, list: VariantList<'_, '_>) -> Result<()> {
542+
pub(crate) fn append_value(&mut self, value: &Variant<'_, '_>) -> Result<bool> {
539543
match self {
540-
Self::List(builder) => builder.append_value(list),
541-
Self::LargeList(builder) => builder.append_value(list),
542-
Self::ListView(builder) => builder.append_value(list),
543-
Self::LargeListView(builder) => builder.append_value(list),
544+
Self::List(builder) => builder.append_value(value),
545+
Self::LargeList(builder) => builder.append_value(value),
546+
Self::ListView(builder) => builder.append_value(value),
547+
Self::LargeListView(builder) => builder.append_value(value),
544548
}
545549
}
546550

@@ -795,6 +799,7 @@ where
795799
element_builder: Box<VariantToShreddedVariantRowBuilder<'a>>,
796800
nulls: NullBufferBuilder,
797801
current_offset: O,
802+
cast_options: &'a CastOptions<'a>,
798803
}
799804

800805
impl<'a, O, const IS_VIEW: bool> VariantToListArrowRowBuilder<'a, O, IS_VIEW>
@@ -826,22 +831,36 @@ where
826831
element_builder: Box::new(element_builder),
827832
nulls: NullBufferBuilder::new(capacity),
828833
current_offset: O::ZERO,
834+
cast_options,
829835
})
830836
}
831837

832-
fn append_null(&mut self) {
838+
fn append_null(&mut self) -> Result<()> {
833839
self.offsets.push(self.current_offset);
834840
self.nulls.append_null();
841+
Ok(())
835842
}
836843

837-
fn append_value(&mut self, list: VariantList<'_, '_>) -> Result<()> {
838-
for element in list.iter() {
839-
self.element_builder.append_value(element)?;
840-
self.current_offset = self.current_offset.add_checked(O::ONE)?;
844+
fn append_value(&mut self, value: &Variant<'_, '_>) -> Result<bool> {
845+
match value {
846+
Variant::List(list) => {
847+
for element in list.iter() {
848+
self.element_builder.append_value(element)?;
849+
self.current_offset = self.current_offset.add_checked(O::ONE)?;
850+
}
851+
self.offsets.push(self.current_offset);
852+
self.nulls.append_non_null();
853+
Ok(true)
854+
}
855+
_ if self.cast_options.safe => {
856+
self.append_null()?;
857+
Ok(false)
858+
}
859+
_ => Err(ArrowError::CastError(format!(
860+
"Failed to extract list from variant {:?}",
861+
value
862+
))),
841863
}
842-
self.offsets.push(self.current_offset);
843-
self.nulls.append_non_null();
844-
Ok(())
845864
}
846865

847866
fn finish(mut self) -> Result<ArrayRef> {

0 commit comments

Comments
 (0)