diff --git a/parquet-variant-compute/src/shred_variant.rs b/parquet-variant-compute/src/shred_variant.rs index 7f253d249dfb..d82eb15af598 100644 --- a/parquet-variant-compute/src/shred_variant.rs +++ b/parquet-variant-compute/src/shred_variant.rs @@ -274,7 +274,7 @@ impl<'a> VariantToShreddedArrayVariantRowBuilder<'a> { fn append_null(&mut self) -> Result<()> { self.value_builder.append_value(Variant::Null); - self.typed_value_builder.append_null(); + self.typed_value_builder.append_null()?; Ok(()) } @@ -284,12 +284,13 @@ impl<'a> VariantToShreddedArrayVariantRowBuilder<'a> { match variant { Variant::List(list) => { self.value_builder.append_null(); - self.typed_value_builder.append_value(list)?; + self.typed_value_builder + .append_value(&Variant::List(list))?; Ok(true) } other => { self.value_builder.append_value(other); - self.typed_value_builder.append_null(); + self.typed_value_builder.append_null()?; Ok(false) } } diff --git a/parquet-variant-compute/src/variant_get.rs b/parquet-variant-compute/src/variant_get.rs index 624c8ae128dc..0c3599b17d68 100644 --- a/parquet-variant-compute/src/variant_get.rs +++ b/parquet-variant-compute/src/variant_get.rs @@ -339,10 +339,11 @@ mod test { Array, ArrayRef, AsArray, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array, - LargeBinaryArray, LargeStringArray, NullBuilder, StringArray, StringViewArray, StructArray, + LargeBinaryArray, LargeListArray, LargeListViewArray, LargeStringArray, ListArray, + ListViewArray, NullBuilder, StringArray, StringViewArray, StructArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, }; - use arrow::buffer::NullBuffer; + use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow::compute::CastOptions; use arrow::datatypes::DataType::{Int16, Int32, Int64}; use arrow::datatypes::i256; @@ -351,8 +352,8 @@ mod test { use arrow_schema::{DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit}; use chrono::DateTime; use parquet_variant::{ - EMPTY_VARIANT_METADATA_BYTES, Variant, VariantDecimal4, VariantDecimal8, VariantDecimal16, - VariantDecimalType, VariantPath, + EMPTY_VARIANT_METADATA_BYTES, Variant, VariantBuilder, VariantDecimal4, VariantDecimal8, + VariantDecimal16, VariantDecimalType, VariantPath, }; fn single_variant_get_test(input_json: &str, path: VariantPath, expected_json: &str) { @@ -4158,4 +4159,182 @@ mod test { assert!(inner_values_result.is_null(1)); assert_eq!(inner_values_result.value(2), 333); } + + #[test] + fn test_variant_get_list_like_safe_cast() { + let string_array: ArrayRef = Arc::new(StringArray::from(vec![ + r#"[1, "two", 3]"#, + "\"not a list\"", + ])); + let variant_array = ArrayRef::from(json_to_variant(&string_array).unwrap()); + + let value_array: ArrayRef = { + let mut builder = VariantBuilder::new(); + builder.append_value("two"); + let (_, value_bytes) = builder.finish(); + Arc::new(BinaryViewArray::from(vec![ + None, + Some(value_bytes.as_slice()), + None, + ])) + }; + let typed_value_array: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), None, Some(3)])); + let struct_fields = Fields::from(vec![ + Field::new("value", DataType::BinaryView, true), + Field::new("typed_value", DataType::Int64, true), + ]); + let struct_array: ArrayRef = Arc::new( + StructArray::try_new( + struct_fields.clone(), + vec![value_array.clone(), typed_value_array.clone()], + None, + ) + .unwrap(), + ); + + let request_field = Arc::new(Field::new("item", DataType::Int64, true)); + let result_field = Arc::new(Field::new("item", DataType::Struct(struct_fields), true)); + + let expectations = vec![ + ( + DataType::List(request_field.clone()), + Arc::new(ListArray::new( + result_field.clone(), + OffsetBuffer::new(ScalarBuffer::from(vec![0, 3, 3])), + struct_array.clone(), + Some(NullBuffer::from(vec![true, false])), + )) as ArrayRef, + ), + ( + DataType::LargeList(request_field.clone()), + Arc::new(LargeListArray::new( + result_field.clone(), + OffsetBuffer::new(ScalarBuffer::from(vec![0, 3, 3])), + struct_array.clone(), + Some(NullBuffer::from(vec![true, false])), + )) as ArrayRef, + ), + ( + DataType::ListView(request_field.clone()), + Arc::new(ListViewArray::new( + result_field.clone(), + ScalarBuffer::from(vec![0, 3]), + ScalarBuffer::from(vec![3, 0]), + struct_array.clone(), + Some(NullBuffer::from(vec![true, false])), + )) as ArrayRef, + ), + ( + DataType::LargeListView(request_field), + Arc::new(LargeListViewArray::new( + result_field, + ScalarBuffer::from(vec![0, 3]), + ScalarBuffer::from(vec![3, 0]), + struct_array, + Some(NullBuffer::from(vec![true, false])), + )) as ArrayRef, + ), + ]; + + for (request_type, expected) in expectations { + let options = GetOptions::new().with_as_type(Some(FieldRef::from(Field::new( + "result", + request_type.clone(), + true, + )))); + + let result = variant_get(&variant_array, options).unwrap(); + assert_eq!(result.data_type(), expected.data_type()); + assert_eq!(&result, &expected); + } + } + + #[test] + fn test_variant_get_list_like_unsafe_cast_errors_on_element_mismatch() { + let string_array: ArrayRef = + Arc::new(StringArray::from(vec![r#"[1, "two", 3]"#, "[4, 5]"])); + let variant_array = ArrayRef::from(json_to_variant(&string_array).unwrap()); + let cast_options = CastOptions { + safe: false, + ..Default::default() + }; + + let item_field = Arc::new(Field::new("item", DataType::Int64, true)); + let request_types = vec![ + DataType::List(item_field.clone()), + DataType::LargeList(item_field.clone()), + DataType::ListView(item_field.clone()), + DataType::LargeListView(item_field), + ]; + + for request_type in request_types { + let options = GetOptions::new() + .with_as_type(Some(FieldRef::from(Field::new( + "result", + request_type.clone(), + true, + )))) + .with_cast_options(cast_options.clone()); + + let err = variant_get(&variant_array, options).unwrap_err(); + assert!( + err.to_string() + .contains("Failed to extract primitive of type Int64") + ); + } + } + + #[test] + fn test_variant_get_list_like_unsafe_cast_errors_on_non_list() { + let string_array: ArrayRef = Arc::new(StringArray::from(vec!["[1, 2]", "\"not a list\""])); + let variant_array = ArrayRef::from(json_to_variant(&string_array).unwrap()); + let cast_options = CastOptions { + safe: false, + ..Default::default() + }; + let item_field = Arc::new(Field::new("item", Int64, true)); + let data_types = vec![ + DataType::List(item_field.clone()), + DataType::LargeList(item_field.clone()), + DataType::ListView(item_field.clone()), + DataType::LargeListView(item_field), + ]; + + for data_type in data_types { + let options = GetOptions::new() + .with_as_type(Some(FieldRef::from(Field::new("result", data_type, true)))) + .with_cast_options(cast_options.clone()); + + let err = variant_get(&variant_array, options).unwrap_err(); + assert!( + err.to_string() + .contains("Failed to extract list from variant"), + ); + } + } + + #[test] + fn test_variant_get_fixed_size_list_not_implemented() { + let string_array: ArrayRef = Arc::new(StringArray::from(vec!["[1, 2]", "\"not a list\""])); + let variant_array = ArrayRef::from(json_to_variant(&string_array).unwrap()); + let item_field = Arc::new(Field::new("item", Int64, true)); + for safe in [true, false] { + let options = GetOptions::new() + .with_as_type(Some(FieldRef::from(Field::new( + "result", + DataType::FixedSizeList(item_field.clone(), 2), + true, + )))) + .with_cast_options(CastOptions { + safe, + ..Default::default() + }); + + let err = variant_get(&variant_array, options).unwrap_err(); + assert!( + err.to_string() + .contains("Converting unshredded variant arrays to arrow fixed-size lists") + ); + } + } } diff --git a/parquet-variant-compute/src/variant_to_arrow.rs b/parquet-variant-compute/src/variant_to_arrow.rs index 172bd4811bc3..17a750eed26e 100644 --- a/parquet-variant-compute/src/variant_to_arrow.rs +++ b/parquet-variant-compute/src/variant_to_arrow.rs @@ -34,7 +34,7 @@ use arrow::compute::{CastOptions, DecimalCast}; use arrow::datatypes::{self, DataType, DecimalType}; use arrow::error::{ArrowError, Result}; use arrow_schema::{FieldRef, TimeUnit}; -use parquet_variant::{Variant, VariantList, VariantPath}; +use parquet_variant::{Variant, VariantPath}; use std::sync::Arc; /// Builder for converting variant values into strongly typed Arrow arrays. @@ -43,6 +43,7 @@ use std::sync::Arc; /// with casting of leaf values to specific types. pub(crate) enum VariantToArrowRowBuilder<'a> { Primitive(PrimitiveVariantToArrowRowBuilder<'a>), + Array(ArrayVariantToArrowRowBuilder<'a>), BinaryVariant(VariantToBinaryVariantArrowRowBuilder), // Path extraction wrapper - contains a boxed enum for any of the above @@ -54,6 +55,7 @@ impl<'a> VariantToArrowRowBuilder<'a> { use VariantToArrowRowBuilder::*; match self { Primitive(b) => b.append_null(), + Array(b) => b.append_null(), BinaryVariant(b) => b.append_null(), WithPath(path_builder) => path_builder.append_null(), } @@ -63,6 +65,7 @@ impl<'a> VariantToArrowRowBuilder<'a> { use VariantToArrowRowBuilder::*; match self { Primitive(b) => b.append_value(&value), + Array(b) => b.append_value(&value), BinaryVariant(b) => b.append_value(value), WithPath(path_builder) => path_builder.append_value(value), } @@ -72,6 +75,7 @@ impl<'a> VariantToArrowRowBuilder<'a> { use VariantToArrowRowBuilder::*; match self { Primitive(b) => b.finish(), + Array(b) => b.finish(), BinaryVariant(b) => b.finish(), WithPath(path_builder) => path_builder.finish(), } @@ -99,15 +103,15 @@ pub(crate) fn make_variant_to_arrow_row_builder<'a>( )); } Some( - DataType::List(_) + data_type @ (DataType::List(_) | DataType::LargeList(_) | DataType::ListView(_) | DataType::LargeListView(_) - | DataType::FixedSizeList(..), + | DataType::FixedSizeList(..)), ) => { - return Err(ArrowError::NotYetImplemented( - "Converting unshredded variant arrays to arrow lists".to_string(), - )); + let builder = + ArrayVariantToArrowRowBuilder::try_new(data_type, cast_options, capacity)?; + Array(builder) } Some(data_type) => { let builder = @@ -526,7 +530,7 @@ impl<'a> ArrayVariantToArrowRowBuilder<'a> { Ok(builder) } - pub(crate) fn append_null(&mut self) { + pub(crate) fn append_null(&mut self) -> Result<()> { match self { Self::List(builder) => builder.append_null(), Self::LargeList(builder) => builder.append_null(), @@ -535,12 +539,12 @@ impl<'a> ArrayVariantToArrowRowBuilder<'a> { } } - pub(crate) fn append_value(&mut self, list: VariantList<'_, '_>) -> Result<()> { + pub(crate) fn append_value(&mut self, value: &Variant<'_, '_>) -> Result { match self { - Self::List(builder) => builder.append_value(list), - Self::LargeList(builder) => builder.append_value(list), - Self::ListView(builder) => builder.append_value(list), - Self::LargeListView(builder) => builder.append_value(list), + Self::List(builder) => builder.append_value(value), + Self::LargeList(builder) => builder.append_value(value), + Self::ListView(builder) => builder.append_value(value), + Self::LargeListView(builder) => builder.append_value(value), } } @@ -795,6 +799,7 @@ where element_builder: Box>, nulls: NullBufferBuilder, current_offset: O, + cast_options: &'a CastOptions<'a>, } impl<'a, O, const IS_VIEW: bool> VariantToListArrowRowBuilder<'a, O, IS_VIEW> @@ -826,22 +831,36 @@ where element_builder: Box::new(element_builder), nulls: NullBufferBuilder::new(capacity), current_offset: O::ZERO, + cast_options, }) } - fn append_null(&mut self) { + fn append_null(&mut self) -> Result<()> { self.offsets.push(self.current_offset); self.nulls.append_null(); + Ok(()) } - fn append_value(&mut self, list: VariantList<'_, '_>) -> Result<()> { - for element in list.iter() { - self.element_builder.append_value(element)?; - self.current_offset = self.current_offset.add_checked(O::ONE)?; + fn append_value(&mut self, value: &Variant<'_, '_>) -> Result { + match value { + Variant::List(list) => { + for element in list.iter() { + self.element_builder.append_value(element)?; + self.current_offset = self.current_offset.add_checked(O::ONE)?; + } + self.offsets.push(self.current_offset); + self.nulls.append_non_null(); + Ok(true) + } + _ if self.cast_options.safe => { + self.append_null()?; + Ok(false) + } + _ => Err(ArrowError::CastError(format!( + "Failed to extract list from variant {:?}", + value + ))), } - self.offsets.push(self.current_offset); - self.nulls.append_non_null(); - Ok(()) } fn finish(mut self) -> Result {