@@ -20,12 +20,13 @@ use arrow::{
2020 datatypes:: Field ,
2121 error:: Result ,
2222} ;
23+ use arrow_schema:: extension:: ExtensionType ;
2324use arrow_schema:: { ArrowError , DataType , FieldRef } ;
2425use parquet_variant:: { VariantPath , VariantPathElement } ;
2526
26- use crate :: VariantArray ;
2727use crate :: variant_array:: BorrowedShreddingState ;
2828use crate :: variant_to_arrow:: make_variant_to_arrow_row_builder;
29+ use crate :: { VariantArray , VariantType , unshred_variant} ;
2930
3031use arrow:: array:: AsArray ;
3132use std:: sync:: Arc ;
@@ -109,6 +110,11 @@ pub(crate) fn follow_shredded_path_element<'a>(
109110 }
110111}
111112
113+ fn is_variant_extension ( field : & Field ) -> bool {
114+ field. extension_type_name ( ) == Some ( VariantType :: NAME )
115+ && field. try_extension_type :: < VariantType > ( ) . is_ok ( )
116+ }
117+
112118/// Follows the given path as far as possible through shredded variant fields. If the path ends on a
113119/// shredded field, return it directly. Otherwise, use a row shredder to follow the rest of the path
114120/// and extract the requested value on a per-row basis.
@@ -131,7 +137,22 @@ fn shredded_get_path(
131137 // Helper that shreds a VariantArray to a specific type.
132138 let shred_basic_variant =
133139 |target : VariantArray , path : VariantPath < ' _ > , as_field : Option < & Field > | {
134- let as_type = as_field. map ( |f| f. data_type ( ) ) ;
140+ let requested_variant = as_field. is_some_and ( is_variant_extension) ;
141+ let target = if requested_variant {
142+ unshred_variant ( & target) ?
143+ } else {
144+ target
145+ } ;
146+
147+ if requested_variant && path. is_empty ( ) {
148+ return Ok ( ArrayRef :: from ( target) ) ;
149+ }
150+
151+ let as_type = if requested_variant {
152+ None
153+ } else {
154+ as_field. map ( |f| f. data_type ( ) )
155+ } ;
135156 let mut builder = make_variant_to_arrow_row_builder (
136157 target. metadata_field ( ) ,
137158 path,
@@ -179,6 +200,16 @@ fn shredded_get_path(
179200 }
180201 ShreddedPathStep :: Missing => {
181202 let num_rows = input. len ( ) ;
203+ if as_field. is_some_and ( is_variant_extension) {
204+ let all_nulls = Some ( arrow:: buffer:: NullBuffer :: from ( vec ! [ false ; num_rows] ) ) ;
205+ let arr = VariantArray :: from_parts (
206+ input. metadata_field ( ) . clone ( ) ,
207+ None ,
208+ None ,
209+ all_nulls,
210+ ) ;
211+ return Ok ( ArrayRef :: from ( arr) ) ;
212+ }
182213 let arr = match as_field. map ( |f| f. data_type ( ) ) {
183214 Some ( data_type) => array:: new_null_array ( data_type, num_rows) ,
184215 None => Arc :: new ( array:: NullArray :: new ( num_rows) ) as _ ,
@@ -222,7 +253,9 @@ fn shredded_get_path(
222253 //
223254 // For shredded/partially-shredded targets (`typed_value` present), recurse into each field
224255 // separately to take advantage of deeper shredding in child fields.
225- if let DataType :: Struct ( fields) = as_field. data_type ( ) {
256+ if !is_variant_extension ( as_field)
257+ && let DataType :: Struct ( fields) = as_field. data_type ( )
258+ {
226259 if target. typed_value_field ( ) . is_none ( ) {
227260 return shred_basic_variant ( target, VariantPath :: default ( ) , Some ( as_field) ) ;
228261 }
@@ -2038,6 +2071,63 @@ mod test {
20382071 println ! ( "Nested path 'a.x' result: {:?}" , result) ;
20392072 }
20402073
/// Requesting a Variant-typed field from the unshredded fixture must produce an
/// unshredded variant result.
#[test]
fn test_variant_get_as_variant_from_unshredded_input() {
    // The fixture returns `(unshredded, shredded)`; only the first element is needed here.
    let unshredded_input = create_variant_get_as_variant_test_data().0;
    assert_variant_field_extraction_returns_unshredded_variant(&unshredded_input);
}
2079+
/// Requesting a Variant-typed field from the shredded fixture must produce the same
/// unshredded variant result as the unshredded fixture does.
#[test]
fn test_variant_get_as_variant_from_shredded_input() {
    // The fixture returns `(unshredded, shredded)`; only the second element is needed here.
    let shredded_input = create_variant_get_as_variant_test_data().1;
    assert_variant_field_extraction_returns_unshredded_variant(&shredded_input);
}
2085+
2086+ fn create_variant_get_as_variant_test_data ( ) -> ( ArrayRef , ArrayRef ) {
2087+ let input_json: ArrayRef = Arc :: new ( StringArray :: from ( vec ! [
2088+ Some ( r#"{"field_name": {"k": 100000}}"# ) ,
2089+ Some ( r#"{"field_name": {"k": "s"}}"# ) ,
2090+ ] ) ) ;
2091+
2092+ let unshredded = ArrayRef :: from ( json_to_variant ( & input_json) . unwrap ( ) ) ;
2093+ let unshredded_variant = VariantArray :: try_new ( & unshredded) . unwrap ( ) ;
2094+
2095+ let as_type = DataType :: Struct ( Fields :: from ( vec ! [ Field :: new(
2096+ "field_name" ,
2097+ DataType :: Struct ( Fields :: from( vec![ Field :: new( "k" , DataType :: Int32 , true ) ] ) ) ,
2098+ true ,
2099+ ) ] ) ) ;
2100+ let shredded = ArrayRef :: from ( shred_variant ( & unshredded_variant, & as_type) . unwrap ( ) ) ;
2101+
2102+ ( unshredded, shredded)
2103+ }
2104+
2105+ fn assert_variant_field_extraction_returns_unshredded_variant ( input : & ArrayRef ) {
2106+ let variant_output = VariantArray :: try_new ( input) . unwrap ( ) . field ( "result" ) ;
2107+ let options = GetOptions :: new_with_path ( VariantPath :: try_from ( "field_name" ) . unwrap ( ) )
2108+ . with_as_type ( Some ( FieldRef :: from ( variant_output) ) ) ;
2109+
2110+ let result = variant_get ( input, options) . unwrap ( ) ;
2111+ let result_variant = VariantArray :: try_new ( & result) . unwrap ( ) ;
2112+
2113+ assert ! ( result_variant. typed_value_field( ) . is_none( ) ) ;
2114+ assert ! ( result_variant. value_field( ) . is_some( ) ) ;
2115+
2116+ let expected_json: ArrayRef = Arc :: new ( StringArray :: from ( vec ! [
2117+ Some ( r#"{"k":100000}"# ) ,
2118+ Some ( r#"{"k":"s"}"# ) ,
2119+ ] ) ) ;
2120+ let expected = json_to_variant ( & expected_json) . unwrap ( ) ;
2121+
2122+ assert_eq ! ( result_variant. len( ) , expected. len( ) ) ;
2123+ for i in 0 ..result_variant. len ( ) {
2124+ assert_eq ! ( result_variant. is_null( i) , expected. is_null( i) ) ;
2125+ if !result_variant. is_null ( i) {
2126+ assert_eq ! ( result_variant. value( i) , expected. value( i) ) ;
2127+ }
2128+ }
2129+ }
2130+
20412131 /// Create test data for depth 0 (direct field access)
20422132 /// [{"x": 42}, {"x": "foo"}, {"y": 10}]
20432133 fn create_depth_0_test_data ( ) -> ArrayRef {
0 commit comments