Skip to content

Commit b7fa60d

Browse files
Kriskras99 and default authored
fix: Don't depend on Serde to provide fields in the right order (#351)
* fix: Different field order between Serde and the Schema
* fix: Downgrade assert to debug_assert

---------

Co-authored-by: default <admin@kriskras99.nl>
1 parent 1c12ae6 commit b7fa60d

File tree

3 files changed: +243 -27 lines changed

avro/src/error.rs

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -516,11 +516,14 @@ pub enum Details {
516516

517517
#[error("Failed to serialize field '{field_name}' for record {record_schema:?}: {error}")]
518518
SerializeRecordFieldWithSchema {
519-
field_name: &'static str,
519+
field_name: String,
520520
record_schema: Schema,
521521
error: Box<Error>,
522522
},
523523

524+
#[error("Missing default for skipped field '{field_name}' for schema {schema:?}")]
525+
MissingDefaultForSkippedField { field_name: String, schema: Schema },
526+
524527
#[error("Failed to deserialize Avro value into value: {0}")]
525528
DeserializeValue(String),
526529

avro/src/ser_schema.rs

Lines changed: 167 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,8 @@ use crate::{
2525
schema::{Name, NamesRef, Namespace, RecordField, RecordSchema, Schema},
2626
};
2727
use bigdecimal::BigDecimal;
28-
use serde::{Serialize, ser};
29-
use std::{borrow::Cow, io::Write, str::FromStr};
28+
use serde::ser;
29+
use std::{borrow::Cow, cmp::Ordering, collections::HashMap, io::Write, str::FromStr};
3030

3131
const COLLECTION_SERIALIZER_ITEM_LIMIT: usize = 1024;
3232
const COLLECTION_SERIALIZER_DEFAULT_INIT_ITEM_CAPACITY: usize = 32;
@@ -249,6 +249,9 @@ impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeMap<'_, '_, W> {
249249
pub struct SchemaAwareWriteSerializeStruct<'a, 's, W: Write> {
250250
ser: &'a mut SchemaAwareWriteSerializer<'s, W>,
251251
record_schema: &'s RecordSchema,
252+
/// Fields we received in the wrong order
253+
field_cache: HashMap<usize, Vec<u8>>,
254+
field_position: usize,
252255
bytes_written: usize,
253256
}
254257

@@ -260,6 +263,8 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 's, W> {
260263
SchemaAwareWriteSerializeStruct {
261264
ser,
262265
record_schema,
266+
field_cache: HashMap::new(),
267+
field_position: 0,
263268
bytes_written: 0,
264269
}
265270
}
@@ -268,19 +273,85 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 's, W> {
268273
where
269274
T: ?Sized + ser::Serialize,
270275
{
271-
// If we receive fields in order, write them directly to the main writer
272-
let mut value_ser = SchemaAwareWriteSerializer::new(
273-
&mut *self.ser.writer,
274-
&field.schema,
275-
self.ser.names,
276-
self.ser.enclosing_namespace.clone(),
277-
);
278-
self.bytes_written += value.serialize(&mut value_ser)?;
279-
280-
Ok(())
276+
match self.field_position.cmp(&field.position) {
277+
Ordering::Equal => {
278+
// If we receive fields in order, write them directly to the main writer
279+
let mut value_ser = SchemaAwareWriteSerializer::new(
280+
&mut *self.ser.writer,
281+
&field.schema,
282+
self.ser.names,
283+
self.ser.enclosing_namespace.clone(),
284+
);
285+
self.bytes_written += value.serialize(&mut value_ser)?;
286+
287+
self.field_position += 1;
288+
while let Some(bytes) = self.field_cache.remove(&self.field_position) {
289+
self.ser
290+
.writer
291+
.write_all(&bytes)
292+
.map_err(Details::WriteBytes)?;
293+
self.bytes_written += bytes.len();
294+
self.field_position += 1;
295+
}
296+
Ok(())
297+
}
298+
Ordering::Less => {
299+
// Current field position is smaller than this field position,
300+
// so we're still missing at least one field, save this field temporarily
301+
let mut bytes = Vec::new();
302+
let mut value_ser = SchemaAwareWriteSerializer::new(
303+
&mut bytes,
304+
&field.schema,
305+
self.ser.names,
306+
self.ser.enclosing_namespace.clone(),
307+
);
308+
value.serialize(&mut value_ser)?;
309+
if self.field_cache.insert(field.position, bytes).is_some() {
310+
Err(Details::FieldNameDuplicate(field.name.clone()).into())
311+
} else {
312+
Ok(())
313+
}
314+
}
315+
Ordering::Greater => {
316+
// Current field position is greater than this field position,
317+
// so we've already had this field
318+
Err(Details::FieldNameDuplicate(field.name.clone()).into())
319+
}
320+
}
281321
}
282322

283-
fn end(self) -> Result<usize, Error> {
323+
fn end(mut self) -> Result<usize, Error> {
324+
// Write any fields that are `serde(skip)` or `serde(skip_serializing)`
325+
while self.field_position != self.record_schema.fields.len() {
326+
let field_info = &self.record_schema.fields[self.field_position];
327+
if let Some(bytes) = self.field_cache.remove(&self.field_position) {
328+
self.ser
329+
.writer
330+
.write_all(&bytes)
331+
.map_err(Details::WriteBytes)?;
332+
self.bytes_written += bytes.len();
333+
self.field_position += 1;
334+
} else if let Some(default) = &field_info.default {
335+
self.serialize_next_field(field_info, default)
336+
.map_err(|e| Details::SerializeRecordFieldWithSchema {
337+
field_name: field_info.name.clone(),
338+
record_schema: Schema::Record(self.record_schema.clone()),
339+
error: Box::new(e),
340+
})?;
341+
} else {
342+
return Err(Details::MissingDefaultForSkippedField {
343+
field_name: field_info.name.clone(),
344+
schema: Schema::Record(self.record_schema.clone()),
345+
}
346+
.into());
347+
}
348+
}
349+
350+
debug_assert!(
351+
self.field_cache.is_empty(),
352+
"There should be no more unwritten fields at this point: {:?}",
353+
self.field_cache
354+
);
284355
Ok(self.bytes_written)
285356
}
286357
}
@@ -304,7 +375,7 @@ impl<W: Write> ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_,
304375
// self.item_count += 1;
305376
self.serialize_next_field(field, value).map_err(|e| {
306377
Details::SerializeRecordFieldWithSchema {
307-
field_name: key,
378+
field_name: key.to_string(),
308379
record_schema: Schema::Record(self.record_schema.clone()),
309380
error: Box::new(e),
310381
}
@@ -323,15 +394,20 @@ impl<W: Write> ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_,
323394
.and_then(|idx| self.record_schema.fields.get(*idx));
324395

325396
if let Some(skipped_field) = skipped_field {
326-
// self.item_count += 1;
327-
skipped_field
328-
.default
329-
.serialize(&mut SchemaAwareWriteSerializer::new(
330-
self.ser.writer,
331-
&skipped_field.schema,
332-
self.ser.names,
333-
self.ser.enclosing_namespace.clone(),
334-
))?;
397+
if let Some(default) = &skipped_field.default {
398+
self.serialize_next_field(skipped_field, default)
399+
.map_err(|e| Details::SerializeRecordFieldWithSchema {
400+
field_name: key.to_string(),
401+
record_schema: Schema::Record(self.record_schema.clone()),
402+
error: Box::new(e),
403+
})?;
404+
} else {
405+
return Err(Details::MissingDefaultForSkippedField {
406+
field_name: key.to_string(),
407+
schema: Schema::Record(self.record_schema.clone()),
408+
}
409+
.into());
410+
}
335411
} else {
336412
return Err(Details::GetField(key.to_string()).into());
337413
}
@@ -1741,12 +1817,13 @@ impl<'a, 's, W: Write> ser::Serializer for &'a mut SchemaAwareWriteSerializer<'s
17411817
mod tests {
17421818
use super::*;
17431819
use crate::{
1744-
Days, Duration, Millis, Months, decimal::Decimal, error::Details, schema::ResolvedSchema,
1820+
Days, Duration, Millis, Months, Reader, Writer, decimal::Decimal, error::Details,
1821+
from_value, schema::ResolvedSchema,
17451822
};
17461823
use apache_avro_test_helper::TestResult;
17471824
use bigdecimal::BigDecimal;
17481825
use num_bigint::{BigInt, Sign};
1749-
use serde::Serialize;
1826+
use serde::{Deserialize, Serialize};
17501827
use serde_bytes::{ByteArray, Bytes};
17511828
use std::{
17521829
collections::{BTreeMap, HashMap},
@@ -2900,4 +2977,69 @@ mod tests {
29002977
string_record.serialize(&mut serializer)?;
29012978
Ok(())
29022979
}
2980+
2981+
#[test]
2982+
fn avro_rs_351_different_field_order_serde_vs_schema() -> TestResult {
2983+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
2984+
struct Foo {
2985+
a: String,
2986+
b: String,
2987+
c: usize,
2988+
d: f64,
2989+
e: usize,
2990+
}
2991+
let schema = Schema::parse_str(
2992+
r#"
2993+
{
2994+
"type":"record",
2995+
"name":"Foo",
2996+
"fields": [
2997+
{
2998+
"name":"b",
2999+
"type":"string"
3000+
},
3001+
{
3002+
"name":"a",
3003+
"type":"string"
3004+
},
3005+
{
3006+
"name":"d",
3007+
"type":"double"
3008+
},
3009+
{
3010+
"name":"e",
3011+
"type":"long"
3012+
},
3013+
{
3014+
"name":"c",
3015+
"type":"long"
3016+
}
3017+
]
3018+
}
3019+
"#,
3020+
)?;
3021+
3022+
let mut writer = Writer::new(&schema, Vec::new())?;
3023+
writer.append_ser(Foo {
3024+
a: "Hello".into(),
3025+
b: "World".into(),
3026+
c: 42,
3027+
d: std::f64::consts::PI,
3028+
e: 5,
3029+
})?;
3030+
let encoded = writer.into_inner()?;
3031+
let mut reader = Reader::with_schema(&schema, &encoded[..])?;
3032+
let decoded = from_value::<Foo>(&reader.next().unwrap()?)?;
3033+
assert_eq!(
3034+
decoded,
3035+
Foo {
3036+
a: "Hello".into(),
3037+
b: "World".into(),
3038+
c: 42,
3039+
d: std::f64::consts::PI,
3040+
e: 5
3041+
}
3042+
);
3043+
Ok(())
3044+
}
29033045
}

avro/tests/avro-rs-226.rs

Lines changed: 72 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@ fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_middle_field
4545
struct T {
4646
x: Option<i8>,
4747
#[serde(skip_serializing_if = "Option::is_none")]
48+
#[avro(default = "null")]
4849
y: Option<String>,
4950
z: Option<i8>,
5051
}
@@ -64,6 +65,7 @@ fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_first_field(
6465
#[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)]
6566
struct T {
6667
#[serde(skip_serializing_if = "Option::is_none")]
68+
#[avro(default = "null")]
6769
x: Option<i8>,
6870
y: Option<String>,
6971
z: Option<i8>,
@@ -86,6 +88,7 @@ fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_last_field()
8688
x: Option<i8>,
8789
y: Option<String>,
8890
#[serde(skip_serializing_if = "Option::is_none")]
91+
#[avro(default = "null")]
8992
z: Option<i8>,
9093
}
9194

@@ -100,18 +103,20 @@ fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_last_field()
100103
}
101104

102105
#[test]
103-
#[ignore = "This test should be re-enabled once the serde-driven deserialization is implemented! PR #227"]
104106
fn avro_rs_226_index_out_of_bounds_with_serde_skip_multiple_fields() -> TestResult {
105107
#[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)]
106108
struct T {
107109
no_skip1: Option<i8>,
108110
#[serde(skip_serializing)]
111+
#[avro(default = "null")]
109112
skip_serializing: Option<String>,
110113
#[serde(skip_serializing_if = "Option::is_none")]
114+
#[avro(default = "null")]
111115
skip_serializing_if: Option<i8>,
112116
#[serde(skip_deserializing)]
113117
skip_deserializing: Option<String>,
114118
#[serde(skip)]
119+
#[avro(skip)]
115120
skip: Option<String>,
116121
no_skip2: Option<i8>,
117122
}
@@ -128,3 +133,69 @@ fn avro_rs_226_index_out_of_bounds_with_serde_skip_multiple_fields() -> TestResu
128133
},
129134
)
130135
}
136+
137+
#[test]
138+
#[should_panic(expected = "Missing default for skipped field 'y' for schema")]
139+
fn avro_rs_351_no_default_for_serde_skip_serializing_if_should_panic() {
140+
#[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)]
141+
struct T {
142+
x: Option<i8>,
143+
#[serde(skip_serializing_if = "Option::is_none")]
144+
y: Option<String>,
145+
z: Option<i8>,
146+
}
147+
148+
ser_deser::<T>(
149+
&T::get_schema(),
150+
T {
151+
x: None,
152+
y: None,
153+
z: Some(1),
154+
},
155+
)
156+
.unwrap()
157+
}
158+
159+
#[test]
160+
#[should_panic(expected = "Missing default for skipped field 'x' for schema")]
161+
fn avro_rs_351_no_default_for_serde_skip_should_panic() {
162+
#[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)]
163+
struct T {
164+
#[serde(skip)]
165+
x: Option<i8>,
166+
y: Option<String>,
167+
z: Option<i8>,
168+
}
169+
170+
ser_deser::<T>(
171+
&T::get_schema(),
172+
T {
173+
x: None,
174+
y: None,
175+
z: Some(1),
176+
},
177+
)
178+
.unwrap()
179+
}
180+
181+
#[test]
182+
#[should_panic(expected = "Missing default for skipped field 'z' for schema")]
183+
fn avro_rs_351_no_default_for_serde_skip_serializing_should_panic() {
184+
#[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)]
185+
struct T {
186+
x: Option<i8>,
187+
y: Option<String>,
188+
#[serde(skip_serializing)]
189+
z: Option<i8>,
190+
}
191+
192+
ser_deser::<T>(
193+
&T::get_schema(),
194+
T {
195+
x: Some(0),
196+
y: None,
197+
z: None,
198+
},
199+
)
200+
.unwrap()
201+
}

0 commit comments

Comments
 (0)