Skip to content

Commit a186161

Browse files
committed
fix: Different field order between Serde and the Schema
1 parent 98d6caf commit a186161

File tree

3 files changed

+222
-24
lines changed

3 files changed

+222
-24
lines changed

avro/src/error.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,11 +516,14 @@ pub enum Details {
516516

517517
#[error("Failed to serialize field '{field_name}' for record {record_schema:?}: {error}")]
518518
SerializeRecordFieldWithSchema {
519-
field_name: &'static str,
519+
field_name: String,
520520
record_schema: Schema,
521521
error: Box<Error>,
522522
},
523523

524+
#[error("Missing default for skipped field '{field_name}' for schema {schema:?}")]
525+
MissingDefaultForSkippedField { field_name: String, schema: Schema },
526+
524527
#[error("Failed to deserialize Avro value into value: {0}")]
525528
DeserializeValue(String),
526529

avro/src/ser_schema.rs

Lines changed: 146 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use crate::{
2525
schema::{Name, NamesRef, Namespace, RecordField, RecordSchema, Schema},
2626
};
2727
use bigdecimal::BigDecimal;
28-
use serde::{Serialize, ser};
28+
use serde::ser;
2929
use std::{borrow::Cow, io::Write, str::FromStr};
3030

3131
const COLLECTION_SERIALIZER_ITEM_LIMIT: usize = 1024;
@@ -249,6 +249,9 @@ impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeMap<'_, '_, W> {
249249
/// Struct serializer that writes the fields of a record through a
/// [`SchemaAwareWriteSerializer`], buffering fields that arrive before their
/// schema position so they can be emitted in schema order.
pub struct SchemaAwareWriteSerializeStruct<'a, 's, W: Write> {
250250
/// Serializer whose writer, names and enclosing namespace are used for every field.
ser: &'a mut SchemaAwareWriteSerializer<'s, W>,
251251
/// Schema of the record being serialized; its `fields` define the output order.
record_schema: &'s RecordSchema,
252+
/// Fields received out of schema order, kept as `(schema position, encoded bytes)`
/// until every earlier field has been written.
253+
field_cache: Vec<(usize, Vec<u8>)>,
254+
/// Schema position of the next field that must be written to the main writer.
next_field: usize,
252255
/// Running count of bytes written to the main writer; returned by `end`.
bytes_written: usize,
253256
}
254257

@@ -260,6 +263,8 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 's, W> {
260263
SchemaAwareWriteSerializeStruct {
261264
ser,
262265
record_schema,
266+
field_cache: Vec::new(),
267+
next_field: 0,
263268
bytes_written: 0,
264269
}
265270
}
@@ -268,19 +273,86 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 's, W> {
268273
where
269274
T: ?Sized + ser::Serialize,
270275
{
271-
// If we receive fields in order, write them directly to the main writer
272-
let mut value_ser = SchemaAwareWriteSerializer::new(
273-
&mut *self.ser.writer,
274-
&field.schema,
275-
self.ser.names,
276-
self.ser.enclosing_namespace.clone(),
277-
);
278-
self.bytes_written += value.serialize(&mut value_ser)?;
276+
if field.position == self.next_field {
277+
// If we receive fields in order, write them directly to the main writer
278+
let mut value_ser = SchemaAwareWriteSerializer::new(
279+
&mut *self.ser.writer,
280+
&field.schema,
281+
self.ser.names,
282+
self.ser.enclosing_namespace.clone(),
283+
);
284+
self.bytes_written += value.serialize(&mut value_ser)?;
279285

286+
self.next_field += 1;
287+
while let Some(index) = self
288+
.field_cache
289+
.iter()
290+
.position(|(pos, _)| pos == &self.next_field)
291+
{
292+
let (_, bytes) = self.field_cache.remove(index);
293+
self.ser
294+
.writer
295+
.write_all(&bytes)
296+
.map_err(Details::WriteBytes)?;
297+
self.bytes_written += bytes.len();
298+
self.next_field += 1;
299+
}
300+
} else {
301+
// This field is in the wrong order, write it to a temporary buffer so we can add it at the right time
302+
let mut bytes = Vec::new();
303+
let mut value_ser = SchemaAwareWriteSerializer::new(
304+
&mut bytes,
305+
&field.schema,
306+
self.ser.names,
307+
self.ser.enclosing_namespace.clone(),
308+
);
309+
value.serialize(&mut value_ser)?;
310+
self.field_cache.push((field.position, bytes));
311+
}
280312
Ok(())
281313
}
282314

283-
fn end(self) -> Result<usize, Error> {
315+
fn end(mut self) -> Result<usize, Error> {
316+
// Write any fields that are `serde(skip)` or `serde(skip_serializing)`
317+
while self.next_field != self.record_schema.fields.len() {
318+
let field_info = &self.record_schema.fields[self.next_field];
319+
if let Some(index) = self
320+
.field_cache
321+
.iter()
322+
.position(|(pos, _)| pos == &self.next_field)
323+
{
324+
let (_, bytes) = self.field_cache.remove(index);
325+
self.ser
326+
.writer
327+
.write_all(&bytes)
328+
.map_err(Details::WriteBytes)?;
329+
self.bytes_written += bytes.len();
330+
self.next_field += 1;
331+
} else if let Some(default) = &field_info.default {
332+
self.serialize_next_field(field_info, default)
333+
.map_err(|e| Details::SerializeRecordFieldWithSchema {
334+
field_name: field_info.name.clone(),
335+
record_schema: Schema::Record(self.record_schema.clone()),
336+
error: Box::new(e),
337+
})?;
338+
} else {
339+
return Err(Details::MissingDefaultForSkippedField {
340+
field_name: field_info.name.clone(),
341+
schema: Schema::Record(self.record_schema.clone()),
342+
}
343+
.into());
344+
}
345+
}
346+
347+
// Check if all fields were written
348+
if self.next_field != self.record_schema.fields.len() {
349+
let name = self.record_schema.fields[self.next_field].name.clone();
350+
return Err(Details::GetField(name).into());
351+
}
352+
assert!(
353+
self.field_cache.is_empty(),
354+
"There should be no more unwritten fields at this point"
355+
);
284356
Ok(self.bytes_written)
285357
}
286358
}
@@ -304,7 +376,7 @@ impl<W: Write> ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_,
304376
// self.item_count += 1;
305377
self.serialize_next_field(field, value).map_err(|e| {
306378
Details::SerializeRecordFieldWithSchema {
307-
field_name: key,
379+
field_name: key.to_string(),
308380
record_schema: Schema::Record(self.record_schema.clone()),
309381
error: Box::new(e),
310382
}
@@ -323,15 +395,20 @@ impl<W: Write> ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_,
323395
.and_then(|idx| self.record_schema.fields.get(*idx));
324396

325397
if let Some(skipped_field) = skipped_field {
326-
// self.item_count += 1;
327-
skipped_field
328-
.default
329-
.serialize(&mut SchemaAwareWriteSerializer::new(
330-
self.ser.writer,
331-
&skipped_field.schema,
332-
self.ser.names,
333-
self.ser.enclosing_namespace.clone(),
334-
))?;
398+
if let Some(default) = &skipped_field.default {
399+
self.serialize_next_field(skipped_field, default)
400+
.map_err(|e| Details::SerializeRecordFieldWithSchema {
401+
field_name: key.to_string(),
402+
record_schema: Schema::Record(self.record_schema.clone()),
403+
error: Box::new(e),
404+
})?;
405+
} else {
406+
return Err(Details::MissingDefaultForSkippedField {
407+
field_name: key.to_string(),
408+
schema: Schema::Record(self.record_schema.clone()),
409+
}
410+
.into());
411+
}
335412
} else {
336413
return Err(Details::GetField(key.to_string()).into());
337414
}
@@ -1741,12 +1818,13 @@ impl<'a, 's, W: Write> ser::Serializer for &'a mut SchemaAwareWriteSerializer<'s
17411818
mod tests {
17421819
use super::*;
17431820
use crate::{
1744-
Days, Duration, Millis, Months, decimal::Decimal, error::Details, schema::ResolvedSchema,
1821+
Days, Duration, Millis, Months, Reader, Writer, decimal::Decimal, error::Details,
1822+
from_value, schema::ResolvedSchema,
17451823
};
17461824
use apache_avro_test_helper::TestResult;
17471825
use bigdecimal::BigDecimal;
17481826
use num_bigint::{BigInt, Sign};
1749-
use serde::Serialize;
1827+
use serde::{Deserialize, Serialize};
17501828
use serde_bytes::{ByteArray, Bytes};
17511829
use std::{
17521830
collections::{BTreeMap, HashMap},
@@ -2900,4 +2978,50 @@ mod tests {
29002978
string_record.serialize(&mut serializer)?;
29012979
Ok(())
29022980
}
2981+
2982+
#[test]
fn different_field_order_serde_vs_schema() -> TestResult {
    // Rust declares the fields as `a`, `b`; the schema below lists `b` before `a`.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Foo {
        a: String,
        b: String,
    }

    let schema = Schema::parse_str(
        r#"
{
"type":"record",
"name":"Foo",
"fields": [
{
"name":"b",
"type":"string"
},
{
"name":"a",
"type":"string"
}
]
}
"#,
    )?;

    let original = Foo {
        a: "Hello".into(),
        b: "World".into(),
    };

    // Serialize through the schema-aware writer.
    let mut writer = Writer::new(&schema, Vec::new())?;
    if let Err(e) = writer.append_ser(original.clone()) {
        panic!("{e:?}");
    }
    let encoded = writer.into_inner()?;

    // Read the single record back and check that it round-trips unchanged.
    let mut reader = Reader::with_schema(&schema, &encoded[..])?;
    let first_value = reader.next().unwrap()?;
    let decoded = from_value::<Foo>(&first_value)?;
    assert_eq!(decoded, original);
    Ok(())
}
29033027
}

avro/tests/avro-rs-226.rs

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_middle_field
4545
struct T {
4646
x: Option<i8>,
4747
#[serde(skip_serializing_if = "Option::is_none")]
48+
#[avro(default = "null")]
4849
y: Option<String>,
4950
z: Option<i8>,
5051
}
@@ -64,6 +65,7 @@ fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_first_field(
6465
#[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)]
6566
struct T {
6667
#[serde(skip_serializing_if = "Option::is_none")]
68+
#[avro(default = "null")]
6769
x: Option<i8>,
6870
y: Option<String>,
6971
z: Option<i8>,
@@ -86,6 +88,7 @@ fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_last_field()
8688
x: Option<i8>,
8789
y: Option<String>,
8890
#[serde(skip_serializing_if = "Option::is_none")]
91+
#[avro(default = "null")]
8992
z: Option<i8>,
9093
}
9194

@@ -100,18 +103,20 @@ fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_last_field()
100103
}
101104

102105
#[test]
103-
#[ignore = "This test should be re-enabled once the serde-driven deserialization is implemented! PR #227"]
104106
fn avro_rs_226_index_out_of_bounds_with_serde_skip_multiple_fields() -> TestResult {
105107
#[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)]
106108
struct T {
107109
no_skip1: Option<i8>,
108110
#[serde(skip_serializing)]
111+
#[avro(default = "null")]
109112
skip_serializing: Option<String>,
110113
#[serde(skip_serializing_if = "Option::is_none")]
114+
#[avro(default = "null")]
111115
skip_serializing_if: Option<i8>,
112116
#[serde(skip_deserializing)]
113117
skip_deserializing: Option<String>,
114118
#[serde(skip)]
119+
#[avro(skip)]
115120
skip: Option<String>,
116121
no_skip2: Option<i8>,
117122
}
@@ -128,3 +133,69 @@ fn avro_rs_226_index_out_of_bounds_with_serde_skip_multiple_fields() -> TestResu
128133
},
129134
)
130135
}
136+
137+
#[test]
#[should_panic(expected = "Missing default for skipped field 'y' for schema")]
fn avro_rs_351_no_default_for_serde_skip_serializing_if_should_panic() {
    // `y` is omitted by Serde when `None`, and declares no Avro default.
    #[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)]
    struct T {
        x: Option<i8>,
        #[serde(skip_serializing_if = "Option::is_none")]
        y: Option<String>,
        z: Option<i8>,
    }

    let schema = T::get_schema();
    let record = T {
        x: None,
        y: None,
        z: Some(1),
    };
    ser_deser::<T>(&schema, record).unwrap()
}
158+
159+
#[test]
#[should_panic(expected = "Missing default for skipped field 'x' for schema")]
fn avro_rs_351_no_default_for_serde_skip_should_panic() {
    // `x` is always skipped by Serde, and declares no Avro default.
    #[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)]
    struct T {
        #[serde(skip)]
        x: Option<i8>,
        y: Option<String>,
        z: Option<i8>,
    }

    let schema = T::get_schema();
    let record = T {
        x: None,
        y: None,
        z: Some(1),
    };
    ser_deser::<T>(&schema, record).unwrap()
}
180+
181+
#[test]
#[should_panic(expected = "Missing default for skipped field 'z' for schema")]
fn avro_rs_351_no_default_for_serde_skip_serializing_should_panic() {
    // `z` is never serialized by Serde, and declares no Avro default.
    #[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)]
    struct T {
        x: Option<i8>,
        y: Option<String>,
        #[serde(skip_serializing)]
        z: Option<i8>,
    }

    let schema = T::get_schema();
    let record = T {
        x: Some(0),
        y: None,
        z: None,
    };
    ser_deser::<T>(&schema, record).unwrap()
}

0 commit comments

Comments
 (0)