@@ -23,9 +23,10 @@ use crate::{
2323 encode:: { encode_int, encode_long} ,
2424 error:: { Details , Error } ,
2525 schema:: { Name , NamesRef , Namespace , RecordField , RecordSchema , Schema } ,
26+ serde:: util:: StringSerializer ,
2627} ;
2728use bigdecimal:: BigDecimal ;
28- use serde:: ser;
29+ use serde:: { Serialize , ser} ;
2930use std:: { borrow:: Cow , cmp:: Ordering , collections:: HashMap , io:: Write , str:: FromStr } ;
3031
3132const COLLECTION_SERIALIZER_ITEM_LIMIT : usize = 1024 ;
@@ -251,6 +252,8 @@ pub struct SchemaAwareWriteSerializeStruct<'a, 's, W: Write> {
251252 record_schema : & ' s RecordSchema ,
252253 /// Fields we received in the wrong order
253254 field_cache : HashMap < usize , Vec < u8 > > ,
255+ /// The current field name when serializing from a map (for `flatten` support).
256+ map_field_name : Option < String > ,
254257 field_position : usize ,
255258 bytes_written : usize ,
256259}
@@ -264,6 +267,7 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 's, W> {
264267 ser,
265268 record_schema,
266269 field_cache : HashMap :: new ( ) ,
270+ map_field_name : None ,
267271 field_position : 0 ,
268272 bytes_written : 0 ,
269273 }
@@ -352,6 +356,10 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 's, W> {
352356 "There should be no more unwritten fields at this point: {:?}" ,
353357 self . field_cache
354358 ) ;
359+ assert ! (
360+ self . map_field_name. is_none( ) ,
361+ "There should be no field name at this point"
362+ ) ;
355363 Ok ( self . bytes_written )
356364 }
357365}
@@ -371,17 +379,14 @@ impl<W: Write> ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_,
371379 . and_then ( |idx| self . record_schema . fields . get ( * idx) ) ;
372380
373381 match record_field {
374- Some ( field) => {
375- // self.item_count += 1;
376- self . serialize_next_field ( field, value) . map_err ( |e| {
377- Details :: SerializeRecordFieldWithSchema {
378- field_name : key. to_string ( ) ,
379- record_schema : Schema :: Record ( self . record_schema . clone ( ) ) ,
380- error : Box :: new ( e) ,
381- }
382- . into ( )
383- } )
384- }
382+ Some ( field) => self . serialize_next_field ( field, value) . map_err ( |e| {
383+ Details :: SerializeRecordFieldWithSchema {
384+ field_name : key. to_string ( ) ,
385+ record_schema : Schema :: Record ( self . record_schema . clone ( ) ) ,
386+ error : Box :: new ( e) ,
387+ }
388+ . into ( )
389+ } ) ,
385390 None => Err ( Details :: FieldName ( String :: from ( key) ) . into ( ) ) ,
386391 }
387392 }
@@ -420,6 +425,50 @@ impl<W: Write> ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_,
420425 }
421426}
422427
428+ impl < W : Write > ser:: SerializeMap for SchemaAwareWriteSerializeStruct < ' _ , ' _ , W > {
429+ type Ok = usize ;
430+ type Error = Error ;
431+
432+ fn serialize_key < T > ( & mut self , key : & T ) -> Result < ( ) , Self :: Error >
433+ where
434+ T : ?Sized + Serialize ,
435+ {
436+ let name = key. serialize ( StringSerializer ) ?;
437+ assert ! (
438+ self . map_field_name. replace( name) . is_none( ) ,
439+ "Got two keys in a row"
440+ ) ;
441+ Ok ( ( ) )
442+ }
443+
444+ fn serialize_value < T > ( & mut self , value : & T ) -> Result < ( ) , Self :: Error >
445+ where
446+ T : ?Sized + Serialize ,
447+ {
448+ let key = self . map_field_name . take ( ) . expect ( "Got value without key" ) ;
449+ let record_field = self
450+ . record_schema
451+ . lookup
452+ . get ( & key)
453+ . and_then ( |idx| self . record_schema . fields . get ( * idx) ) ;
454+ match record_field {
455+ Some ( field) => self . serialize_next_field ( field, value) . map_err ( |e| {
456+ Details :: SerializeRecordFieldWithSchema {
457+ field_name : key. to_string ( ) ,
458+ record_schema : Schema :: Record ( self . record_schema . clone ( ) ) ,
459+ error : Box :: new ( e) ,
460+ }
461+ . into ( )
462+ } ) ,
463+ None => Err ( Details :: FieldName ( key) . into ( ) ) ,
464+ }
465+ }
466+
467+ fn end ( self ) -> Result < Self :: Ok , Self :: Error > {
468+ self . end ( )
469+ }
470+ }
471+
423472impl < W : Write > ser:: SerializeStructVariant for SchemaAwareWriteSerializeStruct < ' _ , ' _ , W > {
424473 type Ok = usize ;
425474 type Error = Error ;
@@ -436,6 +485,46 @@ impl<W: Write> ser::SerializeStructVariant for SchemaAwareWriteSerializeStruct<'
436485 }
437486}
438487
488+ /// Map serializer that switches between Struct or Map.
489+ ///
490+ /// This exists because when `#[serde(flatten)]` is used, struct fields are serialized as a map.
491+ pub enum SchemaAwareWriteSerializeMapOrStruct < ' a , ' s , W : Write > {
492+ Struct ( SchemaAwareWriteSerializeStruct < ' a , ' s , W > ) ,
493+ Map ( SchemaAwareWriteSerializeMap < ' a , ' s , W > ) ,
494+ }
495+
496+ impl < W : Write > ser:: SerializeMap for SchemaAwareWriteSerializeMapOrStruct < ' _ , ' _ , W > {
497+ type Ok = usize ;
498+ type Error = Error ;
499+
500+ fn serialize_key < T > ( & mut self , key : & T ) -> Result < ( ) , Self :: Error >
501+ where
502+ T : ?Sized + Serialize ,
503+ {
504+ match self {
505+ Self :: Struct ( s) => s. serialize_key ( key) ,
506+ Self :: Map ( s) => s. serialize_key ( key) ,
507+ }
508+ }
509+
510+ fn serialize_value < T > ( & mut self , value : & T ) -> Result < ( ) , Self :: Error >
511+ where
512+ T : ?Sized + Serialize ,
513+ {
514+ match self {
515+ Self :: Struct ( s) => s. serialize_value ( value) ,
516+ Self :: Map ( s) => s. serialize_value ( value) ,
517+ }
518+ }
519+
520+ fn end ( self ) -> Result < Self :: Ok , Self :: Error > {
521+ match self {
522+ Self :: Struct ( s) => s. end ( ) ,
523+ Self :: Map ( s) => s. end ( ) ,
524+ }
525+ }
526+ }
527+
439528/// The tuple struct serializer for [`SchemaAwareWriteSerializer`].
440529/// [`SchemaAwareWriteSerializeTupleStruct`] can serialize to an Avro array, record, or big-decimal.
441530/// When serializing to a record, fields must be provided in the correct order, since no names are provided.
@@ -1499,7 +1588,7 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
14991588 & ' a mut self ,
15001589 len : Option < usize > ,
15011590 schema : & ' s Schema ,
1502- ) -> Result < SchemaAwareWriteSerializeMap < ' a , ' s , W > , Error > {
1591+ ) -> Result < SchemaAwareWriteSerializeMapOrStruct < ' a , ' s , W > , Error > {
15031592 let create_error = |cause : String | {
15041593 let len_str = len
15051594 . map ( |l| format ! ( "{l}" ) )
@@ -1513,10 +1602,8 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
15131602 } ;
15141603
15151604 match schema {
1516- Schema :: Map ( map_schema) => Ok ( SchemaAwareWriteSerializeMap :: new (
1517- self ,
1518- map_schema. types . as_ref ( ) ,
1519- len,
1605+ Schema :: Map ( map_schema) => Ok ( SchemaAwareWriteSerializeMapOrStruct :: Map (
1606+ SchemaAwareWriteSerializeMap :: new ( self , map_schema. types . as_ref ( ) , len) ,
15201607 ) ) ,
15211608 Schema :: Union ( union_schema) => {
15221609 for ( i, variant_schema) in union_schema. schemas . iter ( ) . enumerate ( ) {
@@ -1532,6 +1619,9 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
15321619 "Expected a Map schema in {union_schema:?}"
15331620 ) ) )
15341621 }
1622+ Schema :: Record ( record_schema) => Ok ( SchemaAwareWriteSerializeMapOrStruct :: Struct (
1623+ SchemaAwareWriteSerializeStruct :: new ( self , record_schema) ,
1624+ ) ) ,
15351625 _ => Err ( create_error ( format ! (
15361626 "Expected Map or Union schema. Got: {schema}"
15371627 ) ) ) ,
@@ -1630,7 +1720,7 @@ impl<'a, 's, W: Write> ser::Serializer for &'a mut SchemaAwareWriteSerializer<'s
16301720 type SerializeTuple = SchemaAwareWriteSerializeSeq < ' a , ' s , W > ;
16311721 type SerializeTupleStruct = SchemaAwareWriteSerializeTupleStruct < ' a , ' s , W > ;
16321722 type SerializeTupleVariant = SchemaAwareWriteSerializeTupleStruct < ' a , ' s , W > ;
1633- type SerializeMap = SchemaAwareWriteSerializeMap < ' a , ' s , W > ;
1723+ type SerializeMap = SchemaAwareWriteSerializeMapOrStruct < ' a , ' s , W > ;
16341724 type SerializeStruct = SchemaAwareWriteSerializeStruct < ' a , ' s , W > ;
16351725 type SerializeStructVariant = SchemaAwareWriteSerializeStruct < ' a , ' s , W > ;
16361726
0 commit comments