1919//!
2020//! <https://arrow.apache.org/docs/format/CanonicalExtensions.html#fixed-shape-tensor>
2121
22- use serde:: { Deserialize , Serialize } ;
22+ use serde_core:: de:: { self , MapAccess , Visitor } ;
23+ use serde_core:: ser:: SerializeStruct ;
24+ use serde_core:: { Deserialize , Deserializer , Serialize , Serializer } ;
25+ use std:: fmt;
2326
2427use crate :: { ArrowError , DataType , extension:: ExtensionType } ;
2528
@@ -129,7 +132,7 @@ impl FixedShapeTensor {
129132}
130133
131134/// Extension type metadata for [`FixedShapeTensor`].
132- #[ derive( Debug , Clone , PartialEq , Deserialize , Serialize ) ]
135+ #[ derive( Debug , Clone , PartialEq ) ]
133136pub struct FixedShapeTensorMetadata {
134137 /// The physical shape of the contained tensors.
135138 shape : Vec < usize > ,
@@ -141,6 +144,143 @@ pub struct FixedShapeTensorMetadata {
141144 permutations : Option < Vec < usize > > ,
142145}
143146
147+ impl Serialize for FixedShapeTensorMetadata {
148+ fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
149+ where
150+ S : Serializer ,
151+ {
152+ let mut state = serializer. serialize_struct ( "FixedShapeTensorMetadata" , 3 ) ?;
153+ state. serialize_field ( "shape" , & self . shape ) ?;
154+ state. serialize_field ( "dim_names" , & self . dim_names ) ?;
155+ state. serialize_field ( "permutations" , & self . permutations ) ?;
156+ state. end ( )
157+ }
158+ }
159+
160+ #[ derive( Debug ) ]
161+ enum MetadataField {
162+ Shape ,
163+ DimNames ,
164+ Permutations ,
165+ }
166+
167+ struct MetadataFieldVisitor ;
168+
169+ impl < ' de > Visitor < ' de > for MetadataFieldVisitor {
170+ type Value = MetadataField ;
171+
172+ fn expecting ( & self , formatter : & mut fmt:: Formatter ) -> fmt:: Result {
173+ formatter. write_str ( "`shape`, `dim_names`, or `permutations`" )
174+ }
175+
176+ fn visit_str < E > ( self , value : & str ) -> Result < MetadataField , E >
177+ where
178+ E : de:: Error ,
179+ {
180+ match value {
181+ "shape" => Ok ( MetadataField :: Shape ) ,
182+ "dim_names" => Ok ( MetadataField :: DimNames ) ,
183+ "permutations" => Ok ( MetadataField :: Permutations ) ,
184+ _ => Err ( de:: Error :: unknown_field (
185+ value,
186+ & [ "shape" , "dim_names" , "permutations" ] ,
187+ ) ) ,
188+ }
189+ }
190+ }
191+
192+ impl < ' de > Deserialize < ' de > for MetadataField {
193+ fn deserialize < D > ( deserializer : D ) -> Result < MetadataField , D :: Error >
194+ where
195+ D : Deserializer < ' de > ,
196+ {
197+ deserializer. deserialize_identifier ( MetadataFieldVisitor )
198+ }
199+ }
200+
201+ struct FixedShapeTensorMetadataVisitor ;
202+
203+ impl < ' de > Visitor < ' de > for FixedShapeTensorMetadataVisitor {
204+ type Value = FixedShapeTensorMetadata ;
205+
206+ fn expecting ( & self , formatter : & mut fmt:: Formatter ) -> fmt:: Result {
207+ formatter. write_str ( "struct FixedShapeTensorMetadata" )
208+ }
209+
210+ fn visit_seq < V > ( self , mut seq : V ) -> Result < FixedShapeTensorMetadata , V :: Error >
211+ where
212+ V : de:: SeqAccess < ' de > ,
213+ {
214+ let shape = seq
215+ . next_element ( ) ?
216+ . ok_or_else ( || de:: Error :: invalid_length ( 0 , & self ) ) ?;
217+ let dim_names = seq
218+ . next_element ( ) ?
219+ . ok_or_else ( || de:: Error :: invalid_length ( 1 , & self ) ) ?;
220+ let permutations = seq
221+ . next_element ( ) ?
222+ . ok_or_else ( || de:: Error :: invalid_length ( 2 , & self ) ) ?;
223+ Ok ( FixedShapeTensorMetadata {
224+ shape,
225+ dim_names,
226+ permutations,
227+ } )
228+ }
229+
230+ fn visit_map < V > ( self , mut map : V ) -> Result < FixedShapeTensorMetadata , V :: Error >
231+ where
232+ V : MapAccess < ' de > ,
233+ {
234+ let mut shape = None ;
235+ let mut dim_names = None ;
236+ let mut permutations = None ;
237+
238+ while let Some ( key) = map. next_key ( ) ? {
239+ match key {
240+ MetadataField :: Shape => {
241+ if shape. is_some ( ) {
242+ return Err ( de:: Error :: duplicate_field ( "shape" ) ) ;
243+ }
244+ shape = Some ( map. next_value ( ) ?) ;
245+ }
246+ MetadataField :: DimNames => {
247+ if dim_names. is_some ( ) {
248+ return Err ( de:: Error :: duplicate_field ( "dim_names" ) ) ;
249+ }
250+ dim_names = Some ( map. next_value ( ) ?) ;
251+ }
252+ MetadataField :: Permutations => {
253+ if permutations. is_some ( ) {
254+ return Err ( de:: Error :: duplicate_field ( "permutations" ) ) ;
255+ }
256+ permutations = Some ( map. next_value ( ) ?) ;
257+ }
258+ }
259+ }
260+
261+ let shape = shape. ok_or_else ( || de:: Error :: missing_field ( "shape" ) ) ?;
262+
263+ Ok ( FixedShapeTensorMetadata {
264+ shape,
265+ dim_names,
266+ permutations,
267+ } )
268+ }
269+ }
270+
271+ impl < ' de > Deserialize < ' de > for FixedShapeTensorMetadata {
272+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
273+ where
274+ D : Deserializer < ' de > ,
275+ {
276+ deserializer. deserialize_struct (
277+ "FixedShapeTensorMetadata" ,
278+ & [ "shape" , "dim_names" , "permutations" ] ,
279+ FixedShapeTensorMetadataVisitor ,
280+ )
281+ }
282+ }
283+
144284impl FixedShapeTensorMetadata {
145285 /// Returns metadata for a fixed shape tensor extension type.
146286 ///
@@ -377,9 +517,8 @@ mod tests {
377517 }
378518
379519 #[ test]
380- #[ should_panic(
381- expected = "FixedShapeTensor metadata deserialization failed: missing field `shape`"
382- ) ]
520+ #[ should_panic( expected = "FixedShapeTensor metadata deserialization failed: \
521+ unknown field `not-shape`, expected one of `shape`, `dim_names`, `permutations`") ]
383522 fn invalid_metadata ( ) {
384523 let fixed_shape_tensor =
385524 FixedShapeTensor :: try_new ( DataType :: Float32 , [ 100 , 200 , 500 ] , None , None ) . unwrap ( ) ;
0 commit comments