Skip to content

Commit f176b28

Browse files
committed
feat: Implement support for #[serde(flatten)]
This is done by adding a `#[avro(flatten)]` attribute so that the schema (for that field) is also flattened, and by adding support in `SchemaAwareWriteSerializer` for serializing a struct via Map instead of Struct. `flatten` does not work with `to_value`, as `to_value` does not have access to the schema.
1 parent 1155f95 commit f176b28

7 files changed

Lines changed: 560 additions & 40 deletions

File tree

avro/src/error.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,9 @@ pub enum Details {
579579

580580
#[error("Cannot convert a slice to Uuid: {0}")]
581581
UuidFromSlice(#[source] uuid::Error),
582+
583+
#[error("Expected String for Map key")]
584+
MapFieldExpectedString,
582585
}
583586

584587
#[derive(thiserror::Error, PartialEq)]

avro/src/serde/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
pub mod de;
22
pub mod ser;
33
pub mod ser_schema;
4+
mod util;

avro/src/serde/ser.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,12 @@ impl ser::SerializeStructVariant for StructVariantSerializer<'_> {
479479
/// Interpret a serializable instance as a `Value`.
480480
///
481481
/// This conversion can fail if the value is not valid as per the Avro specification.
482-
/// e.g: HashMap with non-string keys
482+
/// e.g: `HashMap` with non-string keys.
483+
///
484+
/// This function does not work if `S` has any fields (recursively) that have the `#[serde(flatten)]`
485+
/// attribute. Please use [`Writer::append_ser`] if that's the case.
486+
///
487+
/// [`Writer::append_ser`]: crate::Writer::append_ser
483488
pub fn to_value<S: Serialize>(value: S) -> Result<Value, Error> {
484489
let mut serializer = Serializer::default();
485490
value.serialize(&mut serializer)

avro/src/serde/ser_schema.rs

Lines changed: 108 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@ use crate::{
2323
encode::{encode_int, encode_long},
2424
error::{Details, Error},
2525
schema::{Name, NamesRef, Namespace, RecordField, RecordSchema, Schema},
26+
serde::util::StringSerializer,
2627
};
2728
use bigdecimal::BigDecimal;
28-
use serde::ser;
29+
use serde::{Serialize, ser};
2930
use std::{borrow::Cow, cmp::Ordering, collections::HashMap, io::Write, str::FromStr};
3031

3132
const COLLECTION_SERIALIZER_ITEM_LIMIT: usize = 1024;
@@ -251,6 +252,8 @@ pub struct SchemaAwareWriteSerializeStruct<'a, 's, W: Write> {
251252
record_schema: &'s RecordSchema,
252253
/// Fields we received in the wrong order
253254
field_cache: HashMap<usize, Vec<u8>>,
255+
/// The current field name when serializing from a map (for `flatten` support).
256+
map_field_name: Option<String>,
254257
field_position: usize,
255258
bytes_written: usize,
256259
}
@@ -264,6 +267,7 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 's, W> {
264267
ser,
265268
record_schema,
266269
field_cache: HashMap::new(),
270+
map_field_name: None,
267271
field_position: 0,
268272
bytes_written: 0,
269273
}
@@ -352,6 +356,10 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 's, W> {
352356
"There should be no more unwritten fields at this point: {:?}",
353357
self.field_cache
354358
);
359+
assert!(
360+
self.map_field_name.is_none(),
361+
"There should be no field name at this point"
362+
);
355363
Ok(self.bytes_written)
356364
}
357365
}
@@ -371,17 +379,14 @@ impl<W: Write> ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_,
371379
.and_then(|idx| self.record_schema.fields.get(*idx));
372380

373381
match record_field {
374-
Some(field) => {
375-
// self.item_count += 1;
376-
self.serialize_next_field(field, value).map_err(|e| {
377-
Details::SerializeRecordFieldWithSchema {
378-
field_name: key.to_string(),
379-
record_schema: Schema::Record(self.record_schema.clone()),
380-
error: Box::new(e),
381-
}
382-
.into()
383-
})
384-
}
382+
Some(field) => self.serialize_next_field(field, value).map_err(|e| {
383+
Details::SerializeRecordFieldWithSchema {
384+
field_name: key.to_string(),
385+
record_schema: Schema::Record(self.record_schema.clone()),
386+
error: Box::new(e),
387+
}
388+
.into()
389+
}),
385390
None => Err(Details::FieldName(String::from(key)).into()),
386391
}
387392
}
@@ -420,6 +425,50 @@ impl<W: Write> ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_,
420425
}
421426
}
422427

428+
impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeStruct<'_, '_, W> {
429+
type Ok = usize;
430+
type Error = Error;
431+
432+
fn serialize_key<T>(&mut self, key: &T) -> Result<(), Self::Error>
433+
where
434+
T: ?Sized + Serialize,
435+
{
436+
let name = key.serialize(StringSerializer)?;
437+
assert!(
438+
self.map_field_name.replace(name).is_none(),
439+
"Got two keys in a row"
440+
);
441+
Ok(())
442+
}
443+
444+
fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error>
445+
where
446+
T: ?Sized + Serialize,
447+
{
448+
let key = self.map_field_name.take().expect("Got value without key");
449+
let record_field = self
450+
.record_schema
451+
.lookup
452+
.get(&key)
453+
.and_then(|idx| self.record_schema.fields.get(*idx));
454+
match record_field {
455+
Some(field) => self.serialize_next_field(field, value).map_err(|e| {
456+
Details::SerializeRecordFieldWithSchema {
457+
field_name: key.to_string(),
458+
record_schema: Schema::Record(self.record_schema.clone()),
459+
error: Box::new(e),
460+
}
461+
.into()
462+
}),
463+
None => Err(Details::FieldName(key).into()),
464+
}
465+
}
466+
467+
fn end(self) -> Result<Self::Ok, Self::Error> {
468+
self.end()
469+
}
470+
}
471+
423472
impl<W: Write> ser::SerializeStructVariant for SchemaAwareWriteSerializeStruct<'_, '_, W> {
424473
type Ok = usize;
425474
type Error = Error;
@@ -436,6 +485,46 @@ impl<W: Write> ser::SerializeStructVariant for SchemaAwareWriteSerializeStruct<'
436485
}
437486
}
438487

488+
/// Map serializer that switches between Struct or Map.
489+
///
490+
/// This exists because when `#[serde(flatten)]` is used, struct fields are serialized as a map.
491+
pub enum SchemaAwareWriteSerializeMapOrStruct<'a, 's, W: Write> {
492+
Struct(SchemaAwareWriteSerializeStruct<'a, 's, W>),
493+
Map(SchemaAwareWriteSerializeMap<'a, 's, W>),
494+
}
495+
496+
impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeMapOrStruct<'_, '_, W> {
497+
type Ok = usize;
498+
type Error = Error;
499+
500+
fn serialize_key<T>(&mut self, key: &T) -> Result<(), Self::Error>
501+
where
502+
T: ?Sized + Serialize,
503+
{
504+
match self {
505+
Self::Struct(s) => s.serialize_key(key),
506+
Self::Map(s) => s.serialize_key(key),
507+
}
508+
}
509+
510+
fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error>
511+
where
512+
T: ?Sized + Serialize,
513+
{
514+
match self {
515+
Self::Struct(s) => s.serialize_value(value),
516+
Self::Map(s) => s.serialize_value(value),
517+
}
518+
}
519+
520+
fn end(self) -> Result<Self::Ok, Self::Error> {
521+
match self {
522+
Self::Struct(s) => s.end(),
523+
Self::Map(s) => s.end(),
524+
}
525+
}
526+
}
527+
439528
/// The tuple struct serializer for [`SchemaAwareWriteSerializer`].
440529
/// [`SchemaAwareWriteSerializeTupleStruct`] can serialize to an Avro array, record, or big-decimal.
441530
/// When serializing to a record, fields must be provided in the correct order, since no names are provided.
@@ -1499,7 +1588,7 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
14991588
&'a mut self,
15001589
len: Option<usize>,
15011590
schema: &'s Schema,
1502-
) -> Result<SchemaAwareWriteSerializeMap<'a, 's, W>, Error> {
1591+
) -> Result<SchemaAwareWriteSerializeMapOrStruct<'a, 's, W>, Error> {
15031592
let create_error = |cause: String| {
15041593
let len_str = len
15051594
.map(|l| format!("{l}"))
@@ -1513,10 +1602,8 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
15131602
};
15141603

15151604
match schema {
1516-
Schema::Map(map_schema) => Ok(SchemaAwareWriteSerializeMap::new(
1517-
self,
1518-
map_schema.types.as_ref(),
1519-
len,
1605+
Schema::Map(map_schema) => Ok(SchemaAwareWriteSerializeMapOrStruct::Map(
1606+
SchemaAwareWriteSerializeMap::new(self, map_schema.types.as_ref(), len),
15201607
)),
15211608
Schema::Union(union_schema) => {
15221609
for (i, variant_schema) in union_schema.schemas.iter().enumerate() {
@@ -1532,6 +1619,9 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
15321619
"Expected a Map schema in {union_schema:?}"
15331620
)))
15341621
}
1622+
Schema::Record(record_schema) => Ok(SchemaAwareWriteSerializeMapOrStruct::Struct(
1623+
SchemaAwareWriteSerializeStruct::new(self, record_schema),
1624+
)),
15351625
_ => Err(create_error(format!(
15361626
"Expected Map or Union schema. Got: {schema}"
15371627
))),
@@ -1630,7 +1720,7 @@ impl<'a, 's, W: Write> ser::Serializer for &'a mut SchemaAwareWriteSerializer<'s
16301720
type SerializeTuple = SchemaAwareWriteSerializeSeq<'a, 's, W>;
16311721
type SerializeTupleStruct = SchemaAwareWriteSerializeTupleStruct<'a, 's, W>;
16321722
type SerializeTupleVariant = SchemaAwareWriteSerializeTupleStruct<'a, 's, W>;
1633-
type SerializeMap = SchemaAwareWriteSerializeMap<'a, 's, W>;
1723+
type SerializeMap = SchemaAwareWriteSerializeMapOrStruct<'a, 's, W>;
16341724
type SerializeStruct = SchemaAwareWriteSerializeStruct<'a, 's, W>;
16351725
type SerializeStructVariant = SchemaAwareWriteSerializeStruct<'a, 's, W>;
16361726

0 commit comments

Comments
 (0)