Skip to content

Commit 28acac3

Browse files
PookieBuns default
authored and committed
feat: Use index bumping for flatten Option for deser
1 parent 8dbc4db commit 28acac3

File tree

4 files changed

+79
-149
lines changed

4 files changed

+79
-149
lines changed

avro/src/serde/deser_schema/enums.rs

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,18 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::{borrow::Borrow, io::Read};
19+
1820
use serde::{
1921
Deserializer,
2022
de::{DeserializeSeed, EnumAccess, Unexpected, VariantAccess, Visitor},
2123
};
22-
use std::{borrow::{Borrow, Cow}, io::Read};
2324

2425
use super::{Config, DESERIALIZE_ANY, SchemaAwareDeserializer, identifier::IdentifierDeserializer};
2526
use crate::{
2627
Error, Schema,
2728
error::Details,
28-
schema::EnumSchema,
29+
schema::{EnumSchema, UnionSchema},
2930
util::zag_i32,
3031
};
3132

@@ -99,21 +100,31 @@ impl<'de, 's, 'r, R: Read> VariantAccess<'de> for PlainEnumDeserializer<'s, 'r,
99100

100101
pub struct UnionEnumDeserializer<'s, 'r, R: Read, S: Borrow<Schema>> {
101102
reader: &'r mut R,
102-
schema: &'s Schema,
103+
variants: &'s [Schema],
103104
config: Config<'s, S>,
105+
branch_index: Option<usize>,
104106
}
105107

106108
impl<'s, 'r, R: Read, S: Borrow<Schema>> UnionEnumDeserializer<'s, 'r, R, S> {
107109
pub fn new(
108110
reader: &'r mut R,
109-
schema: &'s Schema,
111+
schema: &'s UnionSchema,
110112
config: Config<'s, S>,
111-
) -> Result<Self, Error> {
112-
Ok(Self {
113+
branch_index: Option<usize>,
114+
) -> Self {
115+
Self {
113116
reader,
114-
schema,
117+
variants: schema.variants(),
115118
config,
116-
})
119+
branch_index,
120+
}
121+
}
122+
123+
fn get_variant_index(&self, branch_index: usize) -> usize {
124+
match self.branch_index {
125+
Some(null_index) if branch_index >= null_index => branch_index - 1,
126+
_ => branch_index,
127+
}
117128
}
118129
}
119130

@@ -127,13 +138,21 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> EnumAccess<'de>
127138
where
128139
V: DeserializeSeed<'de>,
129140
{
130-
let name = match self.schema.name() {
131-
Some(name) => Cow::Borrowed(name.name()),
132-
None => Cow::Owned(self.schema.to_string()),
141+
let index = match self.branch_index {
142+
Some(index) => index,
143+
None => {
144+
let index = zag_i32(self.reader)?;
145+
usize::try_from(index).map_err(|e| Details::ConvertI32ToUsize(e, index))?
146+
}
133147
};
148+
let schema = self.variants.get(index).ok_or(Details::GetUnionVariant {
149+
index: index as i64,
150+
num_variants: self.variants.len(),
151+
})?;
152+
let variant_index = self.get_variant_index(index);
134153
Ok((
135-
seed.deserialize(IdentifierDeserializer::string(&name))?,
136-
UnionVariantAccess::new(self.schema, self.reader, self.config)?,
154+
seed.deserialize(IdentifierDeserializer::index(variant_index as u32))?,
155+
UnionVariantAccess::new(schema, self.reader, self.config)?,
137156
))
138157
}
139158
}

avro/src/serde/deser_schema/mod.rs

Lines changed: 20 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ pub struct SchemaAwareDeserializer<'s, 'r, R: Read, S: Borrow<Schema>> {
7878
/// This schema is guaranteed to not be a [`Schema::Ref`].
7979
schema: &'s Schema,
8080
config: Config<'s, S>,
81+
branch_index: Option<usize>,
8182
}
8283

8384
impl<'s, 'r, R: Read, S: Borrow<Schema>> SchemaAwareDeserializer<'s, 'r, R, S> {
@@ -95,12 +96,14 @@ impl<'s, 'r, R: Read, S: Borrow<Schema>> SchemaAwareDeserializer<'s, 'r, R, S> {
9596
reader,
9697
schema,
9798
config,
99+
branch_index: None,
98100
})
99101
} else {
100102
Ok(Self {
101103
reader,
102104
schema,
103105
config,
106+
branch_index: None,
104107
})
105108
}
106109
}
@@ -128,18 +131,22 @@ impl<'s, 'r, R: Read, S: Borrow<Schema>> SchemaAwareDeserializer<'s, 'r, R, S> {
128131
Ok(self)
129132
}
130133

131-
fn with_nullable_union_three_plus_variants(
132-
self,
133-
) -> ThreePlusVariantUnionDeserializer<'s, 'r, R, S> {
134-
ThreePlusVariantUnionDeserializer::new(self)
134+
fn with_branch_index(mut self, branch_index: usize) -> Self {
135+
self.branch_index = Some(branch_index);
136+
self
135137
}
136138

137139
/// Read the union and create a new deserializer with the existing reader and config.
138140
///
139141
/// This will resolve the read schema if it is a reference.
140142
fn with_union(self, schema: &'s UnionSchema) -> Result<Self, Error> {
141-
let index = zag_i32(self.reader)?;
142-
let index = usize::try_from(index).map_err(|e| Details::ConvertI32ToUsize(e, index))?;
143+
let index = match self.branch_index {
144+
Some(index) => index,
145+
None => {
146+
let index = zag_i32(self.reader)?;
147+
usize::try_from(index).map_err(|e| Details::ConvertI32ToUsize(e, index))?
148+
}
149+
};
143150
let variant = schema.get_variant(index)?;
144151
self.with_different_schema(variant)
145152
}
@@ -540,10 +547,7 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de>
540547
if union.variants().len() == 2 {
541548
visitor.visit_some(self.with_different_schema(schema)?)
542549
} else {
543-
visitor.visit_some(
544-
self.with_different_schema(schema)?
545-
.with_nullable_union_three_plus_variants(),
546-
)
550+
visitor.visit_some(self.with_branch_index(index))
547551
}
548552
}
549553
} else {
@@ -719,17 +723,12 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de>
719723
Schema::Enum(schema) => {
720724
visitor.visit_enum(PlainEnumDeserializer::new(self.reader, schema))
721725
}
722-
Schema::Union(union) => {
723-
let index = zag_i32(self.reader)?;
724-
let index =
725-
usize::try_from(index).map_err(|e| Details::ConvertI32ToUsize(e, index))?;
726-
let schema = union.get_variant(index)?;
727-
visitor.visit_enum(UnionEnumDeserializer::new(
728-
self.reader,
729-
schema,
730-
self.config,
731-
)?)
732-
}
726+
Schema::Union(union) => visitor.visit_enum(UnionEnumDeserializer::new(
727+
self.reader,
728+
union,
729+
self.config,
730+
self.branch_index,
731+
)),
733732
_ => Err(self.error("enum", "Expected Schema::Enum | Schema::Union")),
734733
}
735734
}
@@ -757,88 +756,6 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de>
757756
}
758757
}
759758

760-
struct ThreePlusVariantUnionDeserializer<'s, 'r, R: Read, S: Borrow<Schema>> {
761-
inner: SchemaAwareDeserializer<'s, 'r, R, S>,
762-
}
763-
764-
impl<'s, 'r, R: Read, S: Borrow<Schema>> ThreePlusVariantUnionDeserializer<'s, 'r, R, S> {
765-
fn new(inner: SchemaAwareDeserializer<'s, 'r, R, S>) -> Self {
766-
Self { inner }
767-
}
768-
}
769-
770-
macro_rules! forward_to_inner_deserializer {
771-
($( $method:ident($($arg:ident: $arg_ty:ty),*); )*) => {
772-
$(
773-
fn $method<V>(self, $($arg: $arg_ty,)* visitor: V) -> Result<V::Value, Self::Error>
774-
where
775-
V: Visitor<'de>,
776-
{
777-
self.inner.$method($($arg,)* visitor)
778-
}
779-
)*
780-
};
781-
}
782-
783-
impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de>
784-
for ThreePlusVariantUnionDeserializer<'s, 'r, R, S>
785-
{
786-
type Error = Error;
787-
788-
fn deserialize_enum<V>(
789-
self,
790-
_name: &'static str,
791-
_variants: &'static [&'static str],
792-
visitor: V,
793-
) -> Result<V::Value, Self::Error>
794-
where
795-
V: Visitor<'de>,
796-
{
797-
visitor.visit_enum(UnionEnumDeserializer::new(
798-
self.inner.reader,
799-
self.inner.schema,
800-
self.inner.config,
801-
)?)
802-
}
803-
804-
fn is_human_readable(&self) -> bool {
805-
self.inner.config.human_readable
806-
}
807-
808-
forward_to_inner_deserializer! {
809-
deserialize_any();
810-
deserialize_bool();
811-
deserialize_i8();
812-
deserialize_i16();
813-
deserialize_i32();
814-
deserialize_i64();
815-
deserialize_i128();
816-
deserialize_u8();
817-
deserialize_u16();
818-
deserialize_u32();
819-
deserialize_u64();
820-
deserialize_u128();
821-
deserialize_f32();
822-
deserialize_f64();
823-
deserialize_char();
824-
deserialize_str();
825-
deserialize_string();
826-
deserialize_bytes();
827-
deserialize_byte_buf();
828-
deserialize_option();
829-
deserialize_unit();
830-
deserialize_seq();
831-
deserialize_map();
832-
deserialize_identifier();
833-
deserialize_ignored_any();
834-
deserialize_unit_struct(name: &'static str);
835-
deserialize_newtype_struct(name: &'static str);
836-
deserialize_tuple(len: usize);
837-
deserialize_tuple_struct(name: &'static str, len: usize);
838-
deserialize_struct(name: &'static str, fields: &'static [&'static str]);
839-
}
840-
}
841-
842759
#[cfg(test)]
843760
mod tests {
844761
use std::fmt::Debug;

0 commit comments

Comments
 (0)