Skip to content

Commit 2019363

Browse files
Kriskras99 and default authored
feat: Implement support for #[serde(flatten)] (#359)
* chore: Move all Serde-related modules to the `serde` module * feat: Implement support for `#[serde(flatten)]` This is done by adding a `#[avro(flatten)]` attribute so that the schema (for that field) is also flattened, and by adding support in `SchemaAwareWriteSerializer` for serializing a struct via Map instead of Struct. `flatten` does not work with `to_value`, as `to_value` does not have access to the schema. * fix: Handle duplicate fields when flatten is used --------- Co-authored-by: default <admin@kriskras99.nl>
1 parent 13625ab commit 2019363

File tree

11 files changed

+650
-50
lines changed

11 files changed

+650
-50
lines changed

avro/src/error.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,12 @@ pub enum Details {
579579

580580
#[error("Cannot convert a slice to Uuid: {0}")]
581581
UuidFromSlice(#[source] uuid::Error),
582+
583+
#[error("Expected String for Map key when serializing a flattened struct")]
584+
MapFieldExpectedString,
585+
586+
#[error("No key for value when serializing a map")]
587+
MapNoKey,
582588
}
583589

584590
#[derive(thiserror::Error, PartialEq)]

avro/src/lib.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -945,14 +945,12 @@
945945
mod bigdecimal;
946946
mod bytes;
947947
mod codec;
948-
mod de;
949948
mod decimal;
950949
mod decode;
951950
mod duration;
952951
mod encode;
953952
mod reader;
954-
mod ser;
955-
mod ser_schema;
953+
mod serde;
956954
mod writer;
957955

958956
pub mod error;
@@ -979,7 +977,6 @@ pub use codec::xz::XzSettings;
979977
#[cfg(feature = "zstandard")]
980978
pub use codec::zstandard::ZstandardSettings;
981979
pub use codec::{Codec, DeflateSettings};
982-
pub use de::from_value;
983980
pub use decimal::Decimal;
984981
pub use duration::{Days, Duration, Millis, Months};
985982
pub use error::Error;
@@ -988,7 +985,7 @@ pub use reader::{
988985
from_avro_datum_reader_schemata, from_avro_datum_schemata, read_marker,
989986
};
990987
pub use schema::{AvroSchema, Schema};
991-
pub use ser::to_value;
988+
pub use serde::{de::from_value, ser::to_value};
992989
pub use uuid::Uuid;
993990
pub use writer::{
994991
GenericSingleObjectWriter, SpecificSingleObjectWriter, Writer, WriterBuilder, to_avro_datum,
File renamed without changes.

avro/src/serde/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
pub mod de;
2+
pub mod ser;
3+
pub mod ser_schema;
4+
mod util;

avro/src/ser.rs renamed to avro/src/serde/ser.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,12 @@ impl ser::SerializeStructVariant for StructVariantSerializer<'_> {
479479
/// Interpret a serializable instance as a `Value`.
480480
///
481481
/// This conversion can fail if the value is not valid as per the Avro specification.
482-
/// e.g: HashMap with non-string keys
482+
/// e.g: `HashMap` with non-string keys.
483+
///
484+
/// This function does not work if `S` has any fields (recursively) that have the `#[serde(flatten)]`
485+
/// attribute. Please use [`Writer::append_ser`] if that's the case.
486+
///
487+
/// [`Writer::append_ser`]: crate::Writer::append_ser
483488
pub fn to_value<S: Serialize>(value: S) -> Result<Value, Error> {
484489
let mut serializer = Serializer::default();
485490
value.serialize(&mut serializer)
Lines changed: 117 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@ use crate::{
2323
encode::{encode_int, encode_long},
2424
error::{Details, Error},
2525
schema::{Name, NamesRef, Namespace, RecordField, RecordSchema, Schema},
26+
serde::util::StringSerializer,
2627
};
2728
use bigdecimal::BigDecimal;
28-
use serde::ser;
29+
use serde::{Serialize, ser};
2930
use std::{borrow::Cow, cmp::Ordering, collections::HashMap, io::Write, str::FromStr};
3031

3132
const COLLECTION_SERIALIZER_ITEM_LIMIT: usize = 1024;
@@ -251,6 +252,8 @@ pub struct SchemaAwareWriteSerializeStruct<'a, 's, W: Write> {
251252
record_schema: &'s RecordSchema,
252253
/// Fields we received in the wrong order
253254
field_cache: HashMap<usize, Vec<u8>>,
255+
/// The current field name when serializing from a map (for `flatten` support).
256+
map_field_name: Option<String>,
254257
field_position: usize,
255258
bytes_written: usize,
256259
}
@@ -264,6 +267,7 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 's, W> {
264267
ser,
265268
record_schema,
266269
field_cache: HashMap::new(),
270+
map_field_name: None,
267271
field_position: 0,
268272
bytes_written: 0,
269273
}
@@ -352,6 +356,11 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 's, W> {
352356
"There should be no more unwritten fields at this point: {:?}",
353357
self.field_cache
354358
);
359+
debug_assert!(
360+
self.map_field_name.is_none(),
361+
"There should be no field name at this point: field {:?}",
362+
self.map_field_name
363+
);
355364
Ok(self.bytes_written)
356365
}
357366
}
@@ -371,17 +380,14 @@ impl<W: Write> ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_,
371380
.and_then(|idx| self.record_schema.fields.get(*idx));
372381

373382
match record_field {
374-
Some(field) => {
375-
// self.item_count += 1;
376-
self.serialize_next_field(field, value).map_err(|e| {
377-
Details::SerializeRecordFieldWithSchema {
378-
field_name: key.to_string(),
379-
record_schema: Schema::Record(self.record_schema.clone()),
380-
error: Box::new(e),
381-
}
382-
.into()
383-
})
384-
}
383+
Some(field) => self.serialize_next_field(field, value).map_err(|e| {
384+
Details::SerializeRecordFieldWithSchema {
385+
field_name: key.to_string(),
386+
record_schema: Schema::Record(self.record_schema.clone()),
387+
error: Box::new(e),
388+
}
389+
.into()
390+
}),
385391
None => Err(Details::FieldName(String::from(key)).into()),
386392
}
387393
}
@@ -420,6 +426,53 @@ impl<W: Write> ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_,
420426
}
421427
}
422428

429+
/// This implementation is used to support `#[serde(flatten)]`, which makes serde serialize the struct through `SerializeMap` instead of `SerializeStruct`.
430+
impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeStruct<'_, '_, W> {
431+
type Ok = usize;
432+
type Error = Error;
433+
434+
fn serialize_key<T>(&mut self, key: &T) -> Result<(), Self::Error>
435+
where
436+
T: ?Sized + Serialize,
437+
{
438+
let name = key.serialize(StringSerializer)?;
439+
let old = self.map_field_name.replace(name);
440+
debug_assert!(
441+
old.is_none(),
442+
"Expected a value instead of a key: old key: {old:?}, new key: {:?}",
443+
self.map_field_name
444+
);
445+
Ok(())
446+
}
447+
448+
fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error>
449+
where
450+
T: ?Sized + Serialize,
451+
{
452+
let key = self.map_field_name.take().ok_or(Details::MapNoKey)?;
453+
let record_field = self
454+
.record_schema
455+
.lookup
456+
.get(&key)
457+
.and_then(|idx| self.record_schema.fields.get(*idx));
458+
match record_field {
459+
Some(field) => self.serialize_next_field(field, value).map_err(|e| {
460+
Details::SerializeRecordFieldWithSchema {
461+
field_name: key.to_string(),
462+
record_schema: Schema::Record(self.record_schema.clone()),
463+
error: Box::new(e),
464+
}
465+
.into()
466+
}),
467+
None => Err(Details::FieldName(key).into()),
468+
}
469+
}
470+
471+
fn end(self) -> Result<Self::Ok, Self::Error> {
472+
self.end()
473+
}
474+
}
475+
423476
impl<W: Write> ser::SerializeStructVariant for SchemaAwareWriteSerializeStruct<'_, '_, W> {
424477
type Ok = usize;
425478
type Error = Error;
@@ -436,6 +489,46 @@ impl<W: Write> ser::SerializeStructVariant for SchemaAwareWriteSerializeStruct<'
436489
}
437490
}
438491

492+
/// Map serializer that switches between Struct and Map.
493+
///
494+
/// This exists because when `#[serde(flatten)]` is used, struct fields are serialized as a map.
495+
pub enum SchemaAwareWriteSerializeMapOrStruct<'a, 's, W: Write> {
496+
Struct(SchemaAwareWriteSerializeStruct<'a, 's, W>),
497+
Map(SchemaAwareWriteSerializeMap<'a, 's, W>),
498+
}
499+
500+
impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeMapOrStruct<'_, '_, W> {
501+
type Ok = usize;
502+
type Error = Error;
503+
504+
fn serialize_key<T>(&mut self, key: &T) -> Result<(), Self::Error>
505+
where
506+
T: ?Sized + Serialize,
507+
{
508+
match self {
509+
Self::Struct(s) => s.serialize_key(key),
510+
Self::Map(s) => s.serialize_key(key),
511+
}
512+
}
513+
514+
fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error>
515+
where
516+
T: ?Sized + Serialize,
517+
{
518+
match self {
519+
Self::Struct(s) => s.serialize_value(value),
520+
Self::Map(s) => s.serialize_value(value),
521+
}
522+
}
523+
524+
fn end(self) -> Result<Self::Ok, Self::Error> {
525+
match self {
526+
Self::Struct(s) => s.end(),
527+
Self::Map(s) => s.end(),
528+
}
529+
}
530+
}
531+
439532
/// The tuple struct serializer for [`SchemaAwareWriteSerializer`].
440533
/// [`SchemaAwareWriteSerializeTupleStruct`] can serialize to an Avro array, record, or big-decimal.
441534
/// When serializing to a record, fields must be provided in the correct order, since no names are provided.
@@ -1499,7 +1592,7 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
14991592
&'a mut self,
15001593
len: Option<usize>,
15011594
schema: &'s Schema,
1502-
) -> Result<SchemaAwareWriteSerializeMap<'a, 's, W>, Error> {
1595+
) -> Result<SchemaAwareWriteSerializeMapOrStruct<'a, 's, W>, Error> {
15031596
let create_error = |cause: String| {
15041597
let len_str = len
15051598
.map(|l| format!("{l}"))
@@ -1513,15 +1606,17 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
15131606
};
15141607

15151608
match schema {
1516-
Schema::Map(map_schema) => Ok(SchemaAwareWriteSerializeMap::new(
1517-
self,
1518-
map_schema.types.as_ref(),
1519-
len,
1609+
Schema::Map(map_schema) => Ok(SchemaAwareWriteSerializeMapOrStruct::Map(
1610+
SchemaAwareWriteSerializeMap::new(self, map_schema.types.as_ref(), len),
15201611
)),
1612+
Schema::Ref { name: ref_name } => {
1613+
let ref_schema = self.get_ref_schema(ref_name)?;
1614+
self.serialize_map_with_schema(len, ref_schema)
1615+
}
15211616
Schema::Union(union_schema) => {
15221617
for (i, variant_schema) in union_schema.schemas.iter().enumerate() {
15231618
match variant_schema {
1524-
Schema::Map(_) => {
1619+
Schema::Map(_) | Schema::Record(_) | Schema::Ref { .. } => {
15251620
encode_int(i as i32, &mut *self.writer)?;
15261621
return self.serialize_map_with_schema(len, variant_schema);
15271622
}
@@ -1532,6 +1627,9 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
15321627
"Expected a Map schema in {union_schema:?}"
15331628
)))
15341629
}
1630+
Schema::Record(record_schema) => Ok(SchemaAwareWriteSerializeMapOrStruct::Struct(
1631+
SchemaAwareWriteSerializeStruct::new(self, record_schema),
1632+
)),
15351633
_ => Err(create_error(format!(
15361634
"Expected Map or Union schema. Got: {schema}"
15371635
))),
@@ -1630,7 +1728,7 @@ impl<'a, 's, W: Write> ser::Serializer for &'a mut SchemaAwareWriteSerializer<'s
16301728
type SerializeTuple = SchemaAwareWriteSerializeSeq<'a, 's, W>;
16311729
type SerializeTupleStruct = SchemaAwareWriteSerializeTupleStruct<'a, 's, W>;
16321730
type SerializeTupleVariant = SchemaAwareWriteSerializeTupleStruct<'a, 's, W>;
1633-
type SerializeMap = SchemaAwareWriteSerializeMap<'a, 's, W>;
1731+
type SerializeMap = SchemaAwareWriteSerializeMapOrStruct<'a, 's, W>;
16341732
type SerializeStruct = SchemaAwareWriteSerializeStruct<'a, 's, W>;
16351733
type SerializeStructVariant = SchemaAwareWriteSerializeStruct<'a, 's, W>;
16361734

0 commit comments

Comments
 (0)