Skip to content

Commit 1bc7431

Browse files
committed
Add variant_get access as Variant
1 parent a1bf90c commit 1bc7431

File tree

1 file changed

+93
-3
lines changed

1 file changed

+93
-3
lines changed

parquet-variant-compute/src/variant_get.rs

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@ use arrow::{
2020
datatypes::Field,
2121
error::Result,
2222
};
23+
use arrow_schema::extension::ExtensionType;
2324
use arrow_schema::{ArrowError, DataType, FieldRef};
2425
use parquet_variant::{VariantPath, VariantPathElement};
2526

26-
use crate::VariantArray;
2727
use crate::variant_array::BorrowedShreddingState;
2828
use crate::variant_to_arrow::make_variant_to_arrow_row_builder;
29+
use crate::{VariantArray, VariantType, unshred_variant};
2930

3031
use arrow::array::AsArray;
3132
use std::sync::Arc;
@@ -109,6 +110,11 @@ pub(crate) fn follow_shredded_path_element<'a>(
109110
}
110111
}
111112

113+
fn is_variant_extension(field: &Field) -> bool {
114+
field.extension_type_name() == Some(VariantType::NAME)
115+
&& field.try_extension_type::<VariantType>().is_ok()
116+
}
117+
112118
/// Follows the given path as far as possible through shredded variant fields. If the path ends on a
113119
/// shredded field, return it directly. Otherwise, use a row shredder to follow the rest of the path
114120
/// and extract the requested value on a per-row basis.
@@ -131,7 +137,22 @@ fn shredded_get_path(
131137
// Helper that shreds a VariantArray to a specific type.
132138
let shred_basic_variant =
133139
|target: VariantArray, path: VariantPath<'_>, as_field: Option<&Field>| {
134-
let as_type = as_field.map(|f| f.data_type());
140+
let requested_variant = as_field.is_some_and(is_variant_extension);
141+
let target = if requested_variant {
142+
unshred_variant(&target)?
143+
} else {
144+
target
145+
};
146+
147+
if requested_variant && path.is_empty() {
148+
return Ok(ArrayRef::from(target));
149+
}
150+
151+
let as_type = if requested_variant {
152+
None
153+
} else {
154+
as_field.map(|f| f.data_type())
155+
};
135156
let mut builder = make_variant_to_arrow_row_builder(
136157
target.metadata_field(),
137158
path,
@@ -179,6 +200,16 @@ fn shredded_get_path(
179200
}
180201
ShreddedPathStep::Missing => {
181202
let num_rows = input.len();
203+
if as_field.is_some_and(is_variant_extension) {
204+
let all_nulls = Some(arrow::buffer::NullBuffer::from(vec![false; num_rows]));
205+
let arr = VariantArray::from_parts(
206+
input.metadata_field().clone(),
207+
None,
208+
None,
209+
all_nulls,
210+
);
211+
return Ok(ArrayRef::from(arr));
212+
}
182213
let arr = match as_field.map(|f| f.data_type()) {
183214
Some(data_type) => array::new_null_array(data_type, num_rows),
184215
None => Arc::new(array::NullArray::new(num_rows)) as _,
@@ -222,7 +253,9 @@ fn shredded_get_path(
222253
//
223254
// For shredded/partially-shredded targets (`typed_value` present), recurse into each field
224255
// separately to take advantage of deeper shredding in child fields.
225-
if let DataType::Struct(fields) = as_field.data_type() {
256+
if !is_variant_extension(as_field)
257+
&& let DataType::Struct(fields) = as_field.data_type()
258+
{
226259
if target.typed_value_field().is_none() {
227260
return shred_basic_variant(target, VariantPath::default(), Some(as_field));
228261
}
@@ -2038,6 +2071,63 @@ mod test {
20382071
println!("Nested path 'a.x' result: {:?}", result);
20392072
}
20402073

2074+
#[test]
2075+
fn test_variant_get_as_variant_from_unshredded_input() {
2076+
let (unshredded, _) = create_variant_get_as_variant_test_data();
2077+
assert_variant_field_extraction_returns_unshredded_variant(&unshredded);
2078+
}
2079+
2080+
#[test]
2081+
fn test_variant_get_as_variant_from_shredded_input() {
2082+
let (_, shredded) = create_variant_get_as_variant_test_data();
2083+
assert_variant_field_extraction_returns_unshredded_variant(&shredded);
2084+
}
2085+
2086+
fn create_variant_get_as_variant_test_data() -> (ArrayRef, ArrayRef) {
2087+
let input_json: ArrayRef = Arc::new(StringArray::from(vec![
2088+
Some(r#"{"field_name": {"k": 100000}}"#),
2089+
Some(r#"{"field_name": {"k": "s"}}"#),
2090+
]));
2091+
2092+
let unshredded = ArrayRef::from(json_to_variant(&input_json).unwrap());
2093+
let unshredded_variant = VariantArray::try_new(&unshredded).unwrap();
2094+
2095+
let as_type = DataType::Struct(Fields::from(vec![Field::new(
2096+
"field_name",
2097+
DataType::Struct(Fields::from(vec![Field::new("k", DataType::Int32, true)])),
2098+
true,
2099+
)]));
2100+
let shredded = ArrayRef::from(shred_variant(&unshredded_variant, &as_type).unwrap());
2101+
2102+
(unshredded, shredded)
2103+
}
2104+
2105+
fn assert_variant_field_extraction_returns_unshredded_variant(input: &ArrayRef) {
2106+
let variant_output = VariantArray::try_new(input).unwrap().field("result");
2107+
let options = GetOptions::new_with_path(VariantPath::try_from("field_name").unwrap())
2108+
.with_as_type(Some(FieldRef::from(variant_output)));
2109+
2110+
let result = variant_get(input, options).unwrap();
2111+
let result_variant = VariantArray::try_new(&result).unwrap();
2112+
2113+
assert!(result_variant.typed_value_field().is_none());
2114+
assert!(result_variant.value_field().is_some());
2115+
2116+
let expected_json: ArrayRef = Arc::new(StringArray::from(vec![
2117+
Some(r#"{"k":100000}"#),
2118+
Some(r#"{"k":"s"}"#),
2119+
]));
2120+
let expected = json_to_variant(&expected_json).unwrap();
2121+
2122+
assert_eq!(result_variant.len(), expected.len());
2123+
for i in 0..result_variant.len() {
2124+
assert_eq!(result_variant.is_null(i), expected.is_null(i));
2125+
if !result_variant.is_null(i) {
2126+
assert_eq!(result_variant.value(i), expected.value(i));
2127+
}
2128+
}
2129+
}
2130+
20412131
/// Create test data for depth 0 (direct field access)
20422132
/// [{"x": 42}, {"x": "foo"}, {"y": 10}]
20432133
fn create_depth_0_test_data() -> ArrayRef {

0 commit comments

Comments
 (0)