Skip to content

Commit 8ceeae9

Browse files
committed
Support serde flatten annotation for schema derivation and de/serialization
Fixes #247
1 parent 940e52f commit 8ceeae9

File tree

4 files changed

+161
-46
lines changed

4 files changed

+161
-46
lines changed

avro/src/schema.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,21 @@ pub struct RecordSchema {
816816
pub attributes: BTreeMap<String, Value>,
817817
}
818818

819+
impl RecordSchema {
820+
const SERDE_FLATTEN_SUPPORT: &str = "avro_rs_serde_flatten_support";
821+
822+
pub fn has_serde_flatten_support(&self) -> bool {
823+
self.attributes
824+
.get(Self::SERDE_FLATTEN_SUPPORT)
825+
.is_some_and(|value| value.as_bool().unwrap_or(false))
826+
}
827+
828+
pub fn set_serde_flatten_support(&mut self) {
829+
self.attributes
830+
.insert(Self::SERDE_FLATTEN_SUPPORT.into(), true.into());
831+
}
832+
}
833+
819834
/// A description of an Enum schema.
820835
#[derive(bon::Builder, Debug, Clone)]
821836
pub struct EnumSchema {

avro/src/writer.rs

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ use crate::{
2323
headers::{HeaderBuilder, RabinFingerprintHeader},
2424
schema::{AvroSchema, Name, ResolvedOwnedSchema, ResolvedSchema, Schema},
2525
ser_schema::SchemaAwareWriteSerializer,
26-
types::Value,
26+
to_value,
27+
types::{Record, Value},
2728
};
2829
use serde::Serialize;
2930
use std::{
@@ -211,22 +212,36 @@ impl<'a, W: Write> Writer<'a, W> {
211212
let n = self.maybe_write_header()?;
212213

213214
match self.resolved_schema {
214-
Some(ref rs) => {
215-
let mut serializer = SchemaAwareWriteSerializer::new(
216-
&mut self.buffer,
217-
self.schema,
218-
rs.get_names(),
219-
None,
220-
);
221-
value.serialize(&mut serializer)?;
222-
self.num_values += 1;
223-
224-
if self.buffer.len() >= self.block_size {
225-
return self.flush().map(|b| b + n);
215+
Some(ref rs) => match self.schema {
216+
Schema::Record(record_schema) if record_schema.has_serde_flatten_support() => {
217+
match to_value(value)? {
218+
Value::Map(m) => {
219+
let mut record = Record::new(self.schema).unwrap();
220+
for (key, value) in m.into_iter() {
221+
record.put(&key, value)
222+
}
223+
self.append(record)
224+
}
225+
value => panic!("expected a map, got {value:?}"),
226+
}
226227
}
228+
_ => {
229+
let mut serializer = SchemaAwareWriteSerializer::new(
230+
&mut self.buffer,
231+
self.schema,
232+
rs.get_names(),
233+
None,
234+
);
235+
value.serialize(&mut serializer)?;
236+
self.num_values += 1;
237+
238+
if self.buffer.len() >= self.block_size {
239+
return self.flush().map(|b| b + n);
240+
}
227241

228-
Ok(n)
229-
}
242+
Ok(n)
243+
}
244+
},
230245
None => {
231246
let rs = ResolvedSchema::try_from(self.schema)?;
232247
self.resolved_schema = Some(rs);

avro_derive/src/lib.rs

Lines changed: 64 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,9 @@ fn get_data_struct_schema_def(
142142
let mut record_field_exprs = vec![];
143143
match s.fields {
144144
syn::Fields::Named(ref a) => {
145-
let mut index: usize = 0;
146145
for field in a.named.iter() {
147146
let mut name = field.ident.as_ref().unwrap().to_string(); // we know everything has a name
147+
let original_name = name.clone();
148148
if let Some(raw_name) = name.strip_prefix("r#") {
149149
name = raw_name.to_string();
150150
}
@@ -163,38 +163,64 @@ fn get_data_struct_schema_def(
163163
}
164164
if let Some(true) = field_attrs.skip {
165165
continue;
166-
}
167-
let default_value = match field_attrs.default {
168-
Some(default_value) => {
169-
let _: serde_json::Value = serde_json::from_str(&default_value[..])
170-
.map_err(|e| {
171-
vec![syn::Error::new(
172-
field.ident.span(),
173-
format!("Invalid avro default json: \n{e}"),
174-
)]
175-
})?;
176-
quote! {
177-
Some(serde_json::from_str(#default_value).expect(format!("Invalid JSON: {:?}", #default_value).as_str()))
178-
}
166+
} else if field.attrs.iter().any(|attr| {
167+
let mut flatten = false;
168+
if attr.path().is_ident("serde") {
169+
let _ = attr.parse_nested_meta(|meta| {
170+
if meta.path.is_ident("flatten") {
171+
flatten = true
172+
}
173+
Ok(())
174+
});
179175
}
180-
None => quote! { None },
181-
};
182-
let aliases = preserve_vec(field_attrs.alias);
183-
let schema_expr = type_to_schema_expr(&field.ty)?;
184-
let position = index;
185-
record_field_exprs.push(quote! {
186-
apache_avro::schema::RecordField {
176+
flatten
177+
}) {
178+
let ty = &field.ty;
179+
record_field_exprs.push(quote! {
180+
match #ty::get_schema() {
181+
apache_avro::Schema::Record(record) =>
182+
for mut field in record.fields {
183+
field.position = index;
184+
index += 1;
185+
schema_fields.push(field);
186+
},
187+
_ => panic!("Can not flatten field {}, only fields with a Record schema can be", #original_name)
188+
}
189+
set_serde_flatten_support = true;
190+
})
191+
} else {
192+
let default_value = match field_attrs.default {
193+
Some(default_value) => {
194+
let _: serde_json::Value = serde_json::from_str(&default_value[..])
195+
.map_err(|e| {
196+
vec![syn::Error::new(
197+
field.ident.span(),
198+
format!("Invalid avro default json: \n{e}"),
199+
)]
200+
})?;
201+
quote! {
202+
Some(serde_json::from_str(#default_value).expect(format!("Invalid JSON: {:?}", #default_value).as_str()))
203+
}
204+
}
205+
None => quote! { None },
206+
};
207+
let aliases = preserve_vec(field_attrs.alias);
208+
let schema_expr = type_to_schema_expr(&field.ty)?;
209+
record_field_exprs.push(quote! {
210+
let position = index;
211+
index += 1;
212+
schema_fields.push(apache_avro::schema::RecordField {
187213
name: #name.to_string(),
188214
doc: #doc,
189215
default: #default_value,
190216
aliases: #aliases,
191217
schema: #schema_expr,
192218
order: apache_avro::schema::RecordFieldOrder::Ascending,
193-
position: #position,
219+
position,
194220
custom_attributes: Default::default(),
195-
}
196-
});
197-
index += 1;
221+
});
222+
});
223+
}
198224
}
199225
}
200226
syn::Fields::Unnamed(_) => {
@@ -213,20 +239,27 @@ fn get_data_struct_schema_def(
213239
let record_doc = preserve_optional(record_doc);
214240
let record_aliases = preserve_vec(aliases);
215241
Ok(quote! {
216-
let schema_fields = vec![#(#record_field_exprs),*];
242+
let mut index = 0;
243+
let mut schema_fields = vec![];
244+
let mut set_serde_flatten_support = false;
245+
#(#record_field_exprs)*
217246
let name = apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse struct name for schema {}", #full_schema_name)[..]);
218247
let lookup: std::collections::BTreeMap<String, usize> = schema_fields
219248
.iter()
220249
.map(|field| (field.name.to_owned(), field.position))
221250
.collect();
222-
apache_avro::schema::Schema::Record(apache_avro::schema::RecordSchema {
251+
let mut schema = apache_avro::schema::RecordSchema {
223252
name,
224253
aliases: #record_aliases,
225254
doc: #record_doc,
226255
fields: schema_fields,
227256
lookup,
228257
attributes: Default::default(),
229-
})
258+
};
259+
if set_serde_flatten_support {
260+
schema.set_serde_flatten_support();
261+
}
262+
apache_avro::schema::Schema::Record(schema)
230263
})
231264
}
232265

@@ -683,7 +716,7 @@ mod tests {
683716
match syn::parse2::<DeriveInput>(test_struct) {
684717
Ok(mut input) => {
685718
let schema_res = derive_avro_schema(&mut input);
686-
let expected_token_stream = r#"let schema_fields = vec ! [apache_avro :: schema :: RecordField { name : "a3" . to_string () , doc : Some ("a doc" . into ()) , default : Some (serde_json :: from_str ("123") . expect (format ! ("Invalid JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec ! ["a1" . into () , "a2" . into ()]) , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position : 0usize , custom_attributes : Default :: default () , }] ;"#;
719+
let expected_token_stream = r#"let position = index ; index += 1 ; schema_fields . push (apache_avro :: schema :: RecordField { name : "a3" . to_string () , doc : Some ("a doc" . into ()) , default : Some (serde_json :: from_str ("123") . expect (format ! ("Invalid JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec ! ["a1" . into () , "a2" . into ()]) , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position , custom_attributes : Default :: default () , }) ;"#;
687720
let schema_token_stream = schema_res.unwrap().to_string();
688721
assert!(schema_token_stream.contains(expected_token_stream));
689722
}
@@ -725,7 +758,7 @@ mod tests {
725758
match syn::parse2::<DeriveInput>(test_struct) {
726759
Ok(mut input) => {
727760
let schema_res = derive_avro_schema(&mut input);
728-
let expected_token_stream = r#"let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name {}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } else { named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; let schema_fields = vec ! [apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position : 0usize , custom_attributes : Default :: default () , } , apache_avro :: schema :: RecordField { name : "DOUBLE_ITEM" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position : 1usize , custom_attributes : Default :: default () , }] ;"#;
761+
let expected_token_stream = r#"let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name {}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } else { named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; let mut index = 0 ; let mut schema_fields = vec ! [] ; let mut set_serde_flatten_support = false ; let position = index ; index += 1 ; schema_fields . push (apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position , custom_attributes : Default :: default () , }) ; let position = index ; index += 1 ; schema_fields . push (apache_avro :: schema :: RecordField { name : "DOUBLE_ITEM" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position , custom_attributes : Default :: default () , }) ;"#;
729762
let schema_token_stream = schema_res.unwrap().to_string();
730763
assert!(schema_token_stream.contains(expected_token_stream));
731764
}
@@ -769,7 +802,7 @@ mod tests {
769802
match syn::parse2::<DeriveInput>(test_struct) {
770803
Ok(mut input) => {
771804
let schema_res = derive_avro_schema(&mut input);
772-
let expected_token_stream = r#"let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name {}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } else { named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; let schema_fields = vec ! [apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position : 0usize , custom_attributes : Default :: default () , } , apache_avro :: schema :: RecordField { name : "DoubleItem" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position : 1usize , custom_attributes : Default :: default () , }] ;"#;
805+
let expected_token_stream = r#"let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name {}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } else { named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; let mut index = 0 ; let mut schema_fields = vec ! [] ; let mut set_serde_flatten_support = false ; let position = index ; index += 1 ; schema_fields . push (apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position , custom_attributes : Default :: default () , }) ; let position = index ; index += 1 ; schema_fields . push (apache_avro :: schema :: RecordField { name : "DoubleItem" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position , custom_attributes : Default :: default () , }) ;"#;
773806
let schema_token_stream = schema_res.unwrap().to_string();
774807
assert!(schema_token_stream.contains(expected_token_stream));
775808
}

avro_derive/tests/derive.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1676,4 +1676,56 @@ mod test_derive {
16761676
panic!("Unexpected schema type for Foo")
16771677
}
16781678
}
1679+
1680+
/// Verifies that a `#[serde(flatten)]` field is expanded into the fields of
/// its own record schema when deriving `AvroSchema`, and that a value of the
/// outer type passes `serde_assert`.
#[test]
fn avro_311_serde_flatten_support() {
    #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)]
    struct Nested {
        a: bool,
    }

    #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)]
    struct Foo {
        #[serde(flatten)]
        nested: Nested,
        b: i32,
    }

    // Expected schema: the single field of `Nested` (`a`) is lifted into `Foo`
    // ahead of `b`. The field name must match the Rust field name `a`,
    // otherwise the equality assertion below can never hold.
    let schema = r#"
    {
        "type":"record",
        "name":"Foo",
        "fields": [
            {
                "name":"a",
                "type":"boolean"
            },
            {
                "name":"b",
                "type":"int"
            }
        ]
    }
    "#;

    let schema = Schema::parse_str(schema).unwrap();
    let derived_schema = Foo::get_schema();
    if let Schema::Record(RecordSchema { name, fields, .. }) = &derived_schema {
        assert_eq!("Foo", name.fullname(None));
        // Only the flattened field and the direct field may appear; in
        // particular there must be no field called `nested`.
        for field in fields {
            match field.name.as_str() {
                "a" | "b" => (), // expected
                name => panic!("Unexpected field name '{name}'"),
            }
        }
    } else {
        panic!("Foo schema must be a record schema: {derived_schema:?}")
    }
    assert_eq!(schema, derived_schema);

    // NOTE(review): `serde_assert` is defined elsewhere in this test module;
    // presumably it round-trips the value through the Avro writer/reader and
    // compares the result. Confirm against its definition.
    serde_assert(Foo {
        nested: Nested { a: true },
        b: 321,
    });
}
16791731
}

0 commit comments

Comments
 (0)