Skip to content

Commit 7e403fa

Browse files
committed
feat: Implement support for #[serde(flatten)]
This is done by adding a `#[avro(flatten]` attribute so that the schema (for that field) is also flattened, and by adding support in `SchemaAwareWriteSerializer` for serializing a struct via Map instead of Struct. `flatten` does not work with `to_value`, as `to_value` does not have access to the schema.
1 parent b4924c0 commit 7e403fa

File tree

7 files changed

+561
-41
lines changed

7 files changed

+561
-41
lines changed

avro/src/error.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,9 @@ pub enum Details {
579579

580580
#[error("Cannot convert a slice to Uuid: {0}")]
581581
UuidFromSlice(#[source] uuid::Error),
582+
583+
#[error("Expected String for Map key")]
584+
MapFieldExpectedString,
582585
}
583586

584587
#[derive(thiserror::Error, PartialEq)]

avro/src/serde/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
pub mod de;
22
pub mod ser;
33
pub mod ser_schema;
4+
mod util;

avro/src/serde/ser.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,12 @@ impl ser::SerializeStructVariant for StructVariantSerializer<'_> {
479479
/// Interpret a serializeable instance as a `Value`.
480480
///
481481
/// This conversion can fail if the value is not valid as per the Avro specification.
482-
/// e.g: HashMap with non-string keys
482+
/// e.g: `HashMap` with non-string keys.
483+
///
484+
/// This function does not work if `S` has any fields (recursively) that have the `#[serde(flatten)]`
485+
/// attribute. Please use [`Writer::append_ser`] if that's the case.
486+
///
487+
/// [`Writer::append_ser`]: crate::Writer::append_ser
483488
pub fn to_value<S: Serialize>(value: S) -> Result<Value, Error> {
484489
let mut serializer = Serializer::default();
485490
value.serialize(&mut serializer)

avro/src/serde/ser_schema.rs

Lines changed: 109 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@ use crate::{
2323
encode::{encode_int, encode_long},
2424
error::{Details, Error},
2525
schema::{Name, NamesRef, Namespace, RecordField, RecordSchema, Schema},
26+
serde::util::StringSerializer,
2627
};
2728
use bigdecimal::BigDecimal;
28-
use serde::ser;
29+
use serde::{Serialize, ser};
2930
use std::{borrow::Cow, io::Write, str::FromStr};
3031

3132
const COLLECTION_SERIALIZER_ITEM_LIMIT: usize = 1024;
@@ -249,8 +250,10 @@ impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeMap<'_, '_, W> {
249250
pub struct SchemaAwareWriteSerializeStruct<'a, 's, W: Write> {
250251
ser: &'a mut SchemaAwareWriteSerializer<'s, W>,
251252
record_schema: &'s RecordSchema,
252-
/// Fields we received in the wrong order
253+
/// Fields we received in the wrong order.
253254
field_cache: Vec<(usize, Vec<u8>)>,
255+
/// The current field name when serializing from a map (for `flatten` support).
256+
map_field_name: Option<String>,
254257
next_field: usize,
255258
bytes_written: usize,
256259
}
@@ -264,6 +267,7 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 's, W> {
264267
ser,
265268
record_schema,
266269
field_cache: Vec::new(),
270+
map_field_name: None,
267271
next_field: 0,
268272
bytes_written: 0,
269273
}
@@ -353,6 +357,10 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 's, W> {
353357
self.field_cache.is_empty(),
354358
"There should be no more unwritten fields at this point"
355359
);
360+
assert!(
361+
self.map_field_name.is_none(),
362+
"There should be no field name at this point"
363+
);
356364
Ok(self.bytes_written)
357365
}
358366
}
@@ -372,17 +380,14 @@ impl<W: Write> ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_,
372380
.and_then(|idx| self.record_schema.fields.get(*idx));
373381

374382
match record_field {
375-
Some(field) => {
376-
// self.item_count += 1;
377-
self.serialize_next_field(field, value).map_err(|e| {
378-
Details::SerializeRecordFieldWithSchema {
379-
field_name: key.to_string(),
380-
record_schema: Schema::Record(self.record_schema.clone()),
381-
error: Box::new(e),
382-
}
383-
.into()
384-
})
385-
}
383+
Some(field) => self.serialize_next_field(field, value).map_err(|e| {
384+
Details::SerializeRecordFieldWithSchema {
385+
field_name: key.to_string(),
386+
record_schema: Schema::Record(self.record_schema.clone()),
387+
error: Box::new(e),
388+
}
389+
.into()
390+
}),
386391
None => Err(Details::FieldName(String::from(key)).into()),
387392
}
388393
}
@@ -421,6 +426,50 @@ impl<W: Write> ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_,
421426
}
422427
}
423428

429+
impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeStruct<'_, '_, W> {
430+
type Ok = usize;
431+
type Error = Error;
432+
433+
fn serialize_key<T>(&mut self, key: &T) -> Result<(), Self::Error>
434+
where
435+
T: ?Sized + Serialize,
436+
{
437+
let name = key.serialize(StringSerializer)?;
438+
assert!(
439+
self.map_field_name.replace(name).is_none(),
440+
"Got two keys in a row"
441+
);
442+
Ok(())
443+
}
444+
445+
fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error>
446+
where
447+
T: ?Sized + Serialize,
448+
{
449+
let key = self.map_field_name.take().expect("Got value without key");
450+
let record_field = self
451+
.record_schema
452+
.lookup
453+
.get(&key)
454+
.and_then(|idx| self.record_schema.fields.get(*idx));
455+
match record_field {
456+
Some(field) => self.serialize_next_field(field, value).map_err(|e| {
457+
Details::SerializeRecordFieldWithSchema {
458+
field_name: key.to_string(),
459+
record_schema: Schema::Record(self.record_schema.clone()),
460+
error: Box::new(e),
461+
}
462+
.into()
463+
}),
464+
None => Err(Details::FieldName(key).into()),
465+
}
466+
}
467+
468+
fn end(self) -> Result<Self::Ok, Self::Error> {
469+
self.end()
470+
}
471+
}
472+
424473
impl<W: Write> ser::SerializeStructVariant for SchemaAwareWriteSerializeStruct<'_, '_, W> {
425474
type Ok = usize;
426475
type Error = Error;
@@ -437,6 +486,46 @@ impl<W: Write> ser::SerializeStructVariant for SchemaAwareWriteSerializeStruct<'
437486
}
438487
}
439488

489+
/// Map serializer that switches between Struct or Map.
490+
///
491+
/// This exists because when `#[serde(flatten)]` is used, struct fields are serialized as a map.
492+
pub enum SchemaAwareWriteSerializeMapOrStruct<'a, 's, W: Write> {
493+
Struct(SchemaAwareWriteSerializeStruct<'a, 's, W>),
494+
Map(SchemaAwareWriteSerializeMap<'a, 's, W>),
495+
}
496+
497+
impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeMapOrStruct<'_, '_, W> {
498+
type Ok = usize;
499+
type Error = Error;
500+
501+
fn serialize_key<T>(&mut self, key: &T) -> Result<(), Self::Error>
502+
where
503+
T: ?Sized + Serialize,
504+
{
505+
match self {
506+
Self::Struct(s) => s.serialize_key(key),
507+
Self::Map(s) => s.serialize_key(key),
508+
}
509+
}
510+
511+
fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error>
512+
where
513+
T: ?Sized + Serialize,
514+
{
515+
match self {
516+
Self::Struct(s) => s.serialize_value(value),
517+
Self::Map(s) => s.serialize_value(value),
518+
}
519+
}
520+
521+
fn end(self) -> Result<Self::Ok, Self::Error> {
522+
match self {
523+
Self::Struct(s) => s.end(),
524+
Self::Map(s) => s.end(),
525+
}
526+
}
527+
}
528+
440529
/// The tuple struct serializer for [`SchemaAwareWriteSerializer`].
441530
/// [`SchemaAwareWriteSerializeTupleStruct`] can serialize to an Avro array, record, or big-decimal.
442531
/// When serializing to a record, fields must be provided in the correct order, since no names are provided.
@@ -1500,7 +1589,7 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
15001589
&'a mut self,
15011590
len: Option<usize>,
15021591
schema: &'s Schema,
1503-
) -> Result<SchemaAwareWriteSerializeMap<'a, 's, W>, Error> {
1592+
) -> Result<SchemaAwareWriteSerializeMapOrStruct<'a, 's, W>, Error> {
15041593
let create_error = |cause: String| {
15051594
let len_str = len
15061595
.map(|l| format!("{l}"))
@@ -1514,10 +1603,8 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
15141603
};
15151604

15161605
match schema {
1517-
Schema::Map(map_schema) => Ok(SchemaAwareWriteSerializeMap::new(
1518-
self,
1519-
map_schema.types.as_ref(),
1520-
len,
1606+
Schema::Map(map_schema) => Ok(SchemaAwareWriteSerializeMapOrStruct::Map(
1607+
SchemaAwareWriteSerializeMap::new(self, map_schema.types.as_ref(), len),
15211608
)),
15221609
Schema::Union(union_schema) => {
15231610
for (i, variant_schema) in union_schema.schemas.iter().enumerate() {
@@ -1533,6 +1620,9 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
15331620
"Expected a Map schema in {union_schema:?}"
15341621
)))
15351622
}
1623+
Schema::Record(record_schema) => Ok(SchemaAwareWriteSerializeMapOrStruct::Struct(
1624+
SchemaAwareWriteSerializeStruct::new(self, record_schema),
1625+
)),
15361626
_ => Err(create_error(format!(
15371627
"Expected Map or Union schema. Got: {schema}"
15381628
))),
@@ -1631,7 +1721,7 @@ impl<'a, 's, W: Write> ser::Serializer for &'a mut SchemaAwareWriteSerializer<'s
16311721
type SerializeTuple = SchemaAwareWriteSerializeSeq<'a, 's, W>;
16321722
type SerializeTupleStruct = SchemaAwareWriteSerializeTupleStruct<'a, 's, W>;
16331723
type SerializeTupleVariant = SchemaAwareWriteSerializeTupleStruct<'a, 's, W>;
1634-
type SerializeMap = SchemaAwareWriteSerializeMap<'a, 's, W>;
1724+
type SerializeMap = SchemaAwareWriteSerializeMapOrStruct<'a, 's, W>;
16351725
type SerializeStruct = SchemaAwareWriteSerializeStruct<'a, 's, W>;
16361726
type SerializeStructVariant = SchemaAwareWriteSerializeStruct<'a, 's, W>;
16371727

0 commit comments

Comments
 (0)