Skip to content

Commit 8243090

Browse files
committed
fix: Union serialization
1 parent 2019363 commit 8243090

4 files changed

Lines changed: 262 additions & 103 deletions

File tree

avro/src/error.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,9 @@ pub enum Details {
274274
#[error("Could not find matching type in {schema:?} for {value:?}")]
275275
FindUnionVariant { schema: UnionSchema, value: Value },
276276

277+
#[error("Union index {index} out of bounds: {num_variants} in {schema:?} for {value:?}")]
278+
UnionIndexOutOfBounds { schema: UnionSchema, value: Value, index: usize, num_variants: usize },
279+
277280
#[error("Union type should not be empty")]
278281
EmptyUnion,
279282

avro/src/serde/ser.rs

Lines changed: 38 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,8 @@ pub struct SeqSerializer {
3131
items: Vec<Value>,
3232
}
3333

34-
pub struct SeqVariantSerializer<'a> {
34+
pub struct SeqVariantSerializer {
3535
index: u32,
36-
variant: &'a str,
3736
items: Vec<Value>,
3837
}
3938

@@ -46,9 +45,8 @@ pub struct StructSerializer {
4645
fields: Vec<(String, Value)>,
4746
}
4847

49-
pub struct StructVariantSerializer<'a> {
48+
pub struct StructVariantSerializer {
5049
index: u32,
51-
variant: &'a str,
5250
fields: Vec<(String, Value)>,
5351
}
5452

@@ -63,15 +61,14 @@ impl SeqSerializer {
6361
}
6462
}
6563

66-
impl<'a> SeqVariantSerializer<'a> {
67-
pub fn new(index: u32, variant: &'a str, len: Option<usize>) -> SeqVariantSerializer<'a> {
64+
impl SeqVariantSerializer {
65+
pub fn new(index: u32, len: Option<usize>) -> SeqVariantSerializer {
6866
let items = match len {
6967
Some(len) => Vec::with_capacity(len),
7068
None => Vec::new(),
7169
};
7270
SeqVariantSerializer {
7371
index,
74-
variant,
7572
items,
7673
}
7774
}
@@ -96,11 +93,10 @@ impl StructSerializer {
9693
}
9794
}
9895

99-
impl<'a> StructVariantSerializer<'a> {
100-
pub fn new(index: u32, variant: &'a str, len: usize) -> StructVariantSerializer<'a> {
96+
impl StructVariantSerializer {
97+
pub fn new(index: u32, len: usize) -> StructVariantSerializer {
10198
StructVariantSerializer {
10299
index,
103-
variant,
104100
fields: Vec::with_capacity(len),
105101
}
106102
}
@@ -112,10 +108,10 @@ impl<'b> ser::Serializer for &'b mut Serializer {
112108
type SerializeSeq = SeqSerializer;
113109
type SerializeTuple = SeqSerializer;
114110
type SerializeTupleStruct = SeqSerializer;
115-
type SerializeTupleVariant = SeqVariantSerializer<'b>;
111+
type SerializeTupleVariant = SeqVariantSerializer;
116112
type SerializeMap = MapSerializer;
117113
type SerializeStruct = StructSerializer;
118-
type SerializeStructVariant = StructVariantSerializer<'b>;
114+
type SerializeStructVariant = StructVariantSerializer;
119115

120116
fn serialize_bool(self, v: bool) -> Result<Self::Ok, Self::Error> {
121117
Ok(Value::Boolean(v))
@@ -226,21 +222,15 @@ impl<'b> ser::Serializer for &'b mut Serializer {
226222

227223
fn serialize_newtype_variant<T>(
228224
self,
229-
_: &'static str,
225+
_name: &'static str,
230226
index: u32,
231-
variant: &'static str,
227+
_variant: &'static str,
232228
value: &T,
233229
) -> Result<Self::Ok, Self::Error>
234230
where
235231
T: Serialize + ?Sized,
236232
{
237-
Ok(Value::Record(vec![
238-
("type".to_owned(), Value::Enum(index, variant.to_owned())),
239-
(
240-
"value".to_owned(),
241-
Value::Union(index, Box::new(value.serialize(self)?)),
242-
),
243-
]))
233+
Ok(Value::Union(index, Box::new(value.serialize(self)?)))
244234
}
245235

246236
fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
@@ -261,12 +251,12 @@ impl<'b> ser::Serializer for &'b mut Serializer {
261251

262252
fn serialize_tuple_variant(
263253
self,
264-
_: &'static str,
254+
_name: &'static str,
265255
index: u32,
266-
variant: &'static str,
256+
_variant: &'static str,
267257
len: usize,
268258
) -> Result<Self::SerializeTupleVariant, Self::Error> {
269-
Ok(SeqVariantSerializer::new(index, variant, Some(len)))
259+
Ok(SeqVariantSerializer::new(index, Some(len)))
270260
}
271261

272262
fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
@@ -283,12 +273,12 @@ impl<'b> ser::Serializer for &'b mut Serializer {
283273

284274
fn serialize_struct_variant(
285275
self,
286-
_: &'static str,
276+
_name: &'static str,
287277
index: u32,
288-
variant: &'static str,
278+
_variant: &'static str,
289279
len: usize,
290280
) -> Result<Self::SerializeStructVariant, Self::Error> {
291-
Ok(StructVariantSerializer::new(index, variant, len))
281+
Ok(StructVariantSerializer::new(index, len))
292282
}
293283

294284
fn is_human_readable(&self) -> bool {
@@ -346,11 +336,11 @@ impl ser::SerializeTupleStruct for SeqSerializer {
346336
}
347337
}
348338

349-
impl ser::SerializeSeq for SeqVariantSerializer<'_> {
339+
impl ser::SerializeTupleVariant for SeqVariantSerializer {
350340
type Ok = Value;
351341
type Error = Error;
352342

353-
fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error>
343+
fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error>
354344
where
355345
T: Serialize + ?Sized,
356346
{
@@ -362,29 +352,7 @@ impl ser::SerializeSeq for SeqVariantSerializer<'_> {
362352
}
363353

364354
fn end(self) -> Result<Self::Ok, Self::Error> {
365-
Ok(Value::Record(vec![
366-
(
367-
"type".to_owned(),
368-
Value::Enum(self.index, self.variant.to_owned()),
369-
),
370-
("value".to_owned(), Value::Array(self.items)),
371-
]))
372-
}
373-
}
374-
375-
impl ser::SerializeTupleVariant for SeqVariantSerializer<'_> {
376-
type Ok = Value;
377-
type Error = Error;
378-
379-
fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error>
380-
where
381-
T: Serialize + ?Sized,
382-
{
383-
ser::SerializeSeq::serialize_element(self, value)
384-
}
385-
386-
fn end(self) -> Result<Self::Ok, Self::Error> {
387-
ser::SerializeSeq::end(self)
355+
Ok(Value::Union(self.index, Box::new(Value::Array(self.items))))
388356
}
389357
}
390358

@@ -447,7 +415,7 @@ impl ser::SerializeStruct for StructSerializer {
447415
}
448416
}
449417

450-
impl ser::SerializeStructVariant for StructVariantSerializer<'_> {
418+
impl ser::SerializeStructVariant for StructVariantSerializer {
451419
type Ok = Value;
452420
type Error = Error;
453421

@@ -463,16 +431,7 @@ impl ser::SerializeStructVariant for StructVariantSerializer<'_> {
463431
}
464432

465433
fn end(self) -> Result<Self::Ok, Self::Error> {
466-
Ok(Value::Record(vec![
467-
(
468-
"type".to_owned(),
469-
Value::Enum(self.index, self.variant.to_owned()),
470-
),
471-
(
472-
"value".to_owned(),
473-
Value::Union(self.index, Box::new(Value::Record(self.fields))),
474-
),
475-
]))
434+
Ok(Value::Union(self.index, Box::new(Value::Record(self.fields))))
476435
}
477436
}
478437

@@ -789,13 +748,7 @@ mod tests {
789748

790749
let expected = Value::Record(vec![(
791750
"a".to_owned(),
792-
Value::Record(vec![
793-
("type".to_owned(), Value::Enum(0, "Double".to_owned())),
794-
(
795-
"value".to_owned(),
796-
Value::Union(0, Box::new(Value::Double(64.0))),
797-
),
798-
]),
751+
Value::Union(0, Box::new(Value::Double(64.0))),
799752
)]);
800753

801754
assert_eq!(
@@ -851,19 +804,13 @@ mod tests {
851804
};
852805
let expected = Value::Record(vec![(
853806
"a".to_owned(),
854-
Value::Record(vec![
855-
("type".to_owned(), Value::Enum(0, "Val1".to_owned())),
856-
(
857-
"value".to_owned(),
858-
Value::Union(
859-
0,
860-
Box::new(Value::Record(vec![
861-
("x".to_owned(), Value::Float(1.0)),
862-
("y".to_owned(), Value::Float(2.0)),
863-
])),
864-
),
865-
),
866-
]),
807+
Value::Union(
808+
0,
809+
Box::new(Value::Record(vec![
810+
("x".to_owned(), Value::Float(1.0)),
811+
("y".to_owned(), Value::Float(2.0)),
812+
])),
813+
),
867814
)]);
868815

869816
assert_eq!(
@@ -965,17 +912,14 @@ mod tests {
965912

966913
let expected = Value::Record(vec![(
967914
"a".to_owned(),
968-
Value::Record(vec![
969-
("type".to_owned(), Value::Enum(1, "Val2".to_owned())),
970-
(
971-
"value".to_owned(),
972-
Value::Array(vec![
973-
Value::Union(1, Box::new(Value::Float(1.0))),
974-
Value::Union(1, Box::new(Value::Float(2.0))),
975-
Value::Union(1, Box::new(Value::Float(3.0))),
976-
]),
977-
),
978-
]),
915+
Value::Union(
916+
1,
917+
Box::new(Value::Array(vec![
918+
Value::Union(1, Box::new(Value::Float(1.0))),
919+
Value::Union(1, Box::new(Value::Float(2.0))),
920+
Value::Union(1, Box::new(Value::Float(3.0))),
921+
])),
922+
),
979923
)]);
980924

981925
assert_eq!(

avro/src/types.rs

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,18 +1024,32 @@ impl Value {
10241024
enclosing_namespace: &Namespace,
10251025
field_default: &Option<JsonValue>,
10261026
) -> Result<Self, Error> {
1027-
let v = match self {
1027+
let (i, inner, v) = match self {
10281028
// Both are unions case.
1029-
Value::Union(_i, v) => *v,
1029+
Value::Union(i, v) => {
1030+
let index = i as usize;
1031+
let inner = schema.schemas.get(index)
1032+
.ok_or_else(|| Details::UnionIndexOutOfBounds {
1033+
schema: schema.clone(),
1034+
value: *v.clone(),
1035+
index,
1036+
num_variants: schema.schemas.len()
1037+
})?;
1038+
1039+
(index, inner, *v)
1040+
},
10301041
// Reader is a union, but writer is not.
1031-
v => v,
1042+
v => {
1043+
let (i, inner) = schema
1044+
.find_schema_with_known_schemata(&v, Some(names), enclosing_namespace)
1045+
.ok_or_else(|| Details::FindUnionVariant {
1046+
schema: schema.clone(),
1047+
value: v.clone(),
1048+
})?;
1049+
1050+
(i, inner, v)
1051+
},
10321052
};
1033-
let (i, inner) = schema
1034-
.find_schema_with_known_schemata(&v, Some(names), enclosing_namespace)
1035-
.ok_or_else(|| Details::FindUnionVariant {
1036-
schema: schema.clone(),
1037-
value: v.clone(),
1038-
})?;
10391053

10401054
Ok(Value::Union(
10411055
i as u32,

0 commit comments

Comments
 (0)