Skip to content

Commit 319bf1d

Browse files
committed
Support serde flatten annotation for schema derivation and de/serialization
Fixes #247
1 parent 940e52f commit 319bf1d

File tree

4 files changed

+151
-45
lines changed

4 files changed

+151
-45
lines changed

avro/src/schema.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2456,6 +2456,9 @@ pub mod derive {
24562456
use super::*;
24572457
use std::borrow::Cow;
24582458

2459+
pub use crate::writer::SERDE_FLATTEN;
2460+
pub use serde_json;
2461+
24592462
/// Trait for types that serve as fully defined components inside an Avro data model. Derive
24602463
/// implementation available through `derive` feature. This is what is implemented by
24612464
/// the `derive(AvroSchema)` macro.

avro/src/writer.rs

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ use crate::{
2323
headers::{HeaderBuilder, RabinFingerprintHeader},
2424
schema::{AvroSchema, Name, ResolvedOwnedSchema, ResolvedSchema, Schema},
2525
ser_schema::SchemaAwareWriteSerializer,
26-
types::Value,
26+
to_value,
27+
types::{Record, Value},
2728
};
2829
use serde::Serialize;
2930
use std::{
@@ -33,6 +34,8 @@ use std::{
3334
const DEFAULT_BLOCK_SIZE: usize = 16000;
3435
const AVRO_OBJECT_HEADER: &[u8] = b"Obj\x01";
3536

37+
pub const SERDE_FLATTEN: &str = "serde-flatten";
38+
3639
/// Main interface for writing Avro formatted values.
3740
///
3841
/// It is critical to call flush before `Writer<W>` is dropped. Though dropping will attempt to flush
@@ -211,22 +214,41 @@ impl<'a, W: Write> Writer<'a, W> {
211214
let n = self.maybe_write_header()?;
212215

213216
match self.resolved_schema {
214-
Some(ref rs) => {
215-
let mut serializer = SchemaAwareWriteSerializer::new(
216-
&mut self.buffer,
217-
self.schema,
218-
rs.get_names(),
219-
None,
220-
);
221-
value.serialize(&mut serializer)?;
222-
self.num_values += 1;
223-
224-
if self.buffer.len() >= self.block_size {
225-
return self.flush().map(|b| b + n);
217+
Some(ref rs) => match self.schema {
218+
Schema::Record(record_schema)
219+
if record_schema
220+
.attributes
221+
.get(SERDE_FLATTEN)
222+
.is_some_and(|value| value.as_bool().unwrap_or(false)) =>
223+
{
224+
match to_value(value)? {
225+
Value::Map(m) => {
226+
let mut record = Record::new(self.schema).unwrap();
227+
for (key, value) in m.into_iter() {
228+
record.put(&key, value)
229+
}
230+
self.append(record)
231+
}
232+
value => panic!("expected a map, got {value:?}"),
233+
}
226234
}
235+
_ => {
236+
let mut serializer = SchemaAwareWriteSerializer::new(
237+
&mut self.buffer,
238+
self.schema,
239+
rs.get_names(),
240+
None,
241+
);
242+
value.serialize(&mut serializer)?;
243+
self.num_values += 1;
244+
245+
if self.buffer.len() >= self.block_size {
246+
return self.flush().map(|b| b + n);
247+
}
227248

228-
Ok(n)
229-
}
249+
Ok(n)
250+
}
251+
},
230252
None => {
231253
let rs = ResolvedSchema::try_from(self.schema)?;
232254
self.resolved_schema = Some(rs);

avro_derive/src/lib.rs

Lines changed: 59 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,9 @@ fn get_data_struct_schema_def(
142142
let mut record_field_exprs = vec![];
143143
match s.fields {
144144
syn::Fields::Named(ref a) => {
145-
let mut index: usize = 0;
146145
for field in a.named.iter() {
147146
let mut name = field.ident.as_ref().unwrap().to_string(); // we know everything has a name
147+
let original_name = name.clone();
148148
if let Some(raw_name) = name.strip_prefix("r#") {
149149
name = raw_name.to_string();
150150
}
@@ -163,38 +163,64 @@ fn get_data_struct_schema_def(
163163
}
164164
if let Some(true) = field_attrs.skip {
165165
continue;
166-
}
167-
let default_value = match field_attrs.default {
168-
Some(default_value) => {
169-
let _: serde_json::Value = serde_json::from_str(&default_value[..])
170-
.map_err(|e| {
171-
vec![syn::Error::new(
172-
field.ident.span(),
173-
format!("Invalid avro default json: \n{e}"),
174-
)]
175-
})?;
176-
quote! {
177-
Some(serde_json::from_str(#default_value).expect(format!("Invalid JSON: {:?}", #default_value).as_str()))
178-
}
166+
} else if field.attrs.iter().any(|attr| {
167+
let mut flatten = false;
168+
if attr.path().is_ident("serde") {
169+
let _ = attr.parse_nested_meta(|meta| {
170+
if meta.path.is_ident("flatten") {
171+
flatten = true
172+
}
173+
Ok(())
174+
});
179175
}
180-
None => quote! { None },
181-
};
182-
let aliases = preserve_vec(field_attrs.alias);
183-
let schema_expr = type_to_schema_expr(&field.ty)?;
184-
let position = index;
185-
record_field_exprs.push(quote! {
186-
apache_avro::schema::RecordField {
176+
flatten
177+
}) {
178+
let ty = &field.ty;
179+
record_field_exprs.push(quote! {
180+
match #ty::get_schema() {
181+
apache_avro::Schema::Record(record) =>
182+
for mut field in record.fields {
183+
field.position = index;
184+
index += 1;
185+
schema_fields.push(field);
186+
},
187+
_ => panic!("Can not flatten field {}, only fields with a Record schema can be", #original_name)
188+
}
189+
attributes.insert(apache_avro::schema::derive::SERDE_FLATTEN.to_string(), true.into());
190+
})
191+
} else {
192+
let default_value = match field_attrs.default {
193+
Some(default_value) => {
194+
let _: serde_json::Value = serde_json::from_str(&default_value[..])
195+
.map_err(|e| {
196+
vec![syn::Error::new(
197+
field.ident.span(),
198+
format!("Invalid avro default json: \n{e}"),
199+
)]
200+
})?;
201+
quote! {
202+
Some(serde_json::from_str(#default_value).expect(format!("Invalid JSON: {:?}", #default_value).as_str()))
203+
}
204+
}
205+
None => quote! { None },
206+
};
207+
let aliases = preserve_vec(field_attrs.alias);
208+
let schema_expr = type_to_schema_expr(&field.ty)?;
209+
record_field_exprs.push(quote! {
210+
let position = index;
211+
index += 1;
212+
schema_fields.push(apache_avro::schema::RecordField {
187213
name: #name.to_string(),
188214
doc: #doc,
189215
default: #default_value,
190216
aliases: #aliases,
191217
schema: #schema_expr,
192218
order: apache_avro::schema::RecordFieldOrder::Ascending,
193-
position: #position,
219+
position,
194220
custom_attributes: Default::default(),
195-
}
196-
});
197-
index += 1;
221+
});
222+
});
223+
}
198224
}
199225
}
200226
syn::Fields::Unnamed(_) => {
@@ -213,7 +239,10 @@ fn get_data_struct_schema_def(
213239
let record_doc = preserve_optional(record_doc);
214240
let record_aliases = preserve_vec(aliases);
215241
Ok(quote! {
216-
let schema_fields = vec![#(#record_field_exprs),*];
242+
let mut index = 0;
243+
let mut schema_fields = vec![];
244+
let mut attributes: std::collections::BTreeMap<String, apache_avro::schema::derive::serde_json::Value> = Default::default();
245+
#(#record_field_exprs)*
217246
let name = apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse struct name for schema {}", #full_schema_name)[..]);
218247
let lookup: std::collections::BTreeMap<String, usize> = schema_fields
219248
.iter()
@@ -225,7 +254,7 @@ fn get_data_struct_schema_def(
225254
doc: #record_doc,
226255
fields: schema_fields,
227256
lookup,
228-
attributes: Default::default(),
257+
attributes,
229258
})
230259
})
231260
}
@@ -683,7 +712,7 @@ mod tests {
683712
match syn::parse2::<DeriveInput>(test_struct) {
684713
Ok(mut input) => {
685714
let schema_res = derive_avro_schema(&mut input);
686-
let expected_token_stream = r#"let schema_fields = vec ! [apache_avro :: schema :: RecordField { name : "a3" . to_string () , doc : Some ("a doc" . into ()) , default : Some (serde_json :: from_str ("123") . expect (format ! ("Invalid JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec ! ["a1" . into () , "a2" . into ()]) , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position : 0usize , custom_attributes : Default :: default () , }] ;"#;
715+
let expected_token_stream = r#"let position = index ; index += 1 ; schema_fields . push (apache_avro :: schema :: RecordField { name : "a3" . to_string () , doc : Some ("a doc" . into ()) , default : Some (serde_json :: from_str ("123") . expect (format ! ("Invalid JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec ! ["a1" . into () , "a2" . into ()]) , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position , custom_attributes : Default :: default () , }) ;"#;
687716
let schema_token_stream = schema_res.unwrap().to_string();
688717
assert!(schema_token_stream.contains(expected_token_stream));
689718
}
@@ -725,7 +754,7 @@ mod tests {
725754
match syn::parse2::<DeriveInput>(test_struct) {
726755
Ok(mut input) => {
727756
let schema_res = derive_avro_schema(&mut input);
728-
let expected_token_stream = r#"let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name {}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } else { named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; let schema_fields = vec ! [apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position : 0usize , custom_attributes : Default :: default () , } , apache_avro :: schema :: RecordField { name : "DOUBLE_ITEM" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position : 1usize , custom_attributes : Default :: default () , }] ;"#;
757+
let expected_token_stream = r#"let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name {}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } else { named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; let mut index = 0 ; let mut schema_fields = vec ! [] ; let mut attributes : std :: collections :: BTreeMap < String , apache_avro :: schema :: derive :: serde_json :: Value > = Default :: default () ; let position = index ; index += 1 ; schema_fields . push (apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position , custom_attributes : Default :: default () , }) ; let position = index ; index += 1 ; schema_fields . push (apache_avro :: schema :: RecordField { name : "DOUBLE_ITEM" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position , custom_attributes : Default :: default () , }) ;"#;
729758
let schema_token_stream = schema_res.unwrap().to_string();
730759
assert!(schema_token_stream.contains(expected_token_stream));
731760
}
@@ -769,7 +798,7 @@ mod tests {
769798
match syn::parse2::<DeriveInput>(test_struct) {
770799
Ok(mut input) => {
771800
let schema_res = derive_avro_schema(&mut input);
772-
let expected_token_stream = r#"let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name {}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } else { named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; let schema_fields = vec ! [apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position : 0usize , custom_attributes : Default :: default () , } , apache_avro :: schema :: RecordField { name : "DoubleItem" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position : 1usize , custom_attributes : Default :: default () , }] ;"#;
801+
let expected_token_stream = r#"let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name {}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } else { named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; let mut index = 0 ; let mut schema_fields = vec ! [] ; let mut attributes : std :: collections :: BTreeMap < String , apache_avro :: schema :: derive :: serde_json :: Value > = Default :: default () ; let position = index ; index += 1 ; schema_fields . push (apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position , custom_attributes : Default :: default () , }) ; let position = index ; index += 1 ; schema_fields . push (apache_avro :: schema :: RecordField { name : "DoubleItem" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position , custom_attributes : Default :: default () , }) ;"#;
773802
let schema_token_stream = schema_res.unwrap().to_string();
774803
assert!(schema_token_stream.contains(expected_token_stream));
775804
}

avro_derive/tests/derive.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1676,4 +1676,56 @@ mod test_derive {
16761676
panic!("Unexpected schema type for Foo")
16771677
}
16781678
}
1679+
1680+
#[test]
1681+
fn test_serde_flatten() {
1682+
#[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)]
1683+
struct Nested {
1684+
a: bool,
1685+
}
1686+
1687+
#[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)]
1688+
struct Foo {
1689+
#[serde(flatten)]
1690+
nested: Nested,
1691+
b: i32,
1692+
}
1693+
1694+
let schema = r#"
1695+
{
1696+
"type":"record",
1697+
"name":"Foo",
1698+
"fields": [
1699+
{
1700+
"name":"a1",
1701+
"type":"boolean"
1702+
},
1703+
{
1704+
"name":"b",
1705+
"type":"int"
1706+
}
1707+
]
1708+
}
1709+
"#;
1710+
1711+
let schema = Schema::parse_str(schema).unwrap();
1712+
let derived_schema = Foo::get_schema();
1713+
if let Schema::Record(RecordSchema { name, fields, .. }) = &derived_schema {
1714+
assert_eq!("Foo", name.fullname(None));
1715+
for field in fields {
1716+
match field.name.as_str() {
1717+
"a" | "b" => (), // expected
1718+
name => panic!("Unexpected field name '{name}'"),
1719+
}
1720+
}
1721+
} else {
1722+
panic!("Foo schema must be a record schema: {derived_schema:?}")
1723+
}
1724+
assert_eq!(schema, derived_schema);
1725+
1726+
serde_assert(Foo {
1727+
nested: Nested { a: true },
1728+
b: 321,
1729+
});
1730+
}
16791731
}

0 commit comments

Comments
 (0)