From 773e158cd4911b9c127c83edfa7f7382138cfe87 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Wed, 19 Jun 2019 17:01:59 +0200 Subject: [PATCH 1/2] feat: implement simple tagging For now behind a feature flags `tags` --- Cargo.toml | 2 + src/de.rs | 217 +++++++++++++++++++++++++++++++++++++++++++++++---- src/lib.rs | 9 +++ src/ser.rs | 75 +++++++++++++++++- src/tag.rs | 40 ++++++++++ tests/de.rs | 15 ++++ tests/ser.rs | 65 +++++++++++++++ 7 files changed, 403 insertions(+), 20 deletions(-) create mode 100644 src/tag.rs diff --git a/Cargo.toml b/Cargo.toml index 18088d53..e3ca232b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ maintenance = { status = "actively-developed" } byteorder = { version = "1.0.0", default-features = false } half = "1.2.0" serde = { version = "1.0.14", default-features = false } +serde_derive = { version = "1.0.14", default-features = false, optional = true } [dev-dependencies] serde_derive = { version = "1.0.14", default-features = false } @@ -31,3 +32,4 @@ default = ["std"] alloc = ["serde/alloc"] std = ["serde/std" ] unsealed_read_write = [] +tags = ["serde_derive"] \ No newline at end of file diff --git a/src/de.rs b/src/de.rs index 69cb4564..6dc52533 100644 --- a/src/de.rs +++ b/src/de.rs @@ -439,6 +439,51 @@ where }) } + #[cfg(feature = "tags")] + fn parse_tag(&mut self, typ: TagTyp, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.recursion_checked(|de| { + let mut len = 2; + let value = visitor.visit_seq(TagAccess { + de, + typ, + len: &mut len, + })?; + + if len != 0 { + Err(de.error(ErrorCode::TrailingData)) + } else { + Ok(value) + } + }) + } + + #[cfg(not(feature = "tags"))] + fn parse_tag(&mut self, typ: TagTyp, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match typ { + TagTyp::U8 => { + self.parse_u8()?; + } + TagTyp::U16 => { + self.parse_u16()?; + } + TagTyp::U32 => { + self.parse_u32()?; + } + TagTyp::U64 => { + self.parse_u64()?; + } + _ => {} + } + + self.parse_value(visitor) + } + fn parse_map(&mut self, mut len: usize, visitor: V) -> Result where V: de::Visitor<'de>, @@ -704,23 +749,11 @@ where 0xbf => self.parse_indefinite_map(visitor), // Major type 6: optional semantic tagging of other major types - 0xc0..=0xd7 => self.parse_value(visitor), - 0xd8 => { - self.parse_u8()?; - self.parse_value(visitor) - } - 0xd9 => { - self.parse_u16()?; - self.parse_value(visitor) - } - 0xda => { - self.parse_u32()?; - self.parse_value(visitor) - } - 0xdb => { - self.parse_u64()?; - self.parse_value(visitor) - } + val @ 0xc0..=0xd7 => self.parse_tag(TagTyp::Inline(val), visitor), + 0xd8 => self.parse_tag(TagTyp::U8, visitor), + 0xd9 => self.parse_tag(TagTyp::U16, visitor), + 0xda => self.parse_tag(TagTyp::U32, visitor), + 0xdb => self.parse_tag(TagTyp::U64, visitor), 0xdc..=0xdf => Err(self.error(ErrorCode::UnassignedCode)), // Major type 7: floating-point numbers and other simple data types that need no content @@ -914,6 +947,156 @@ where } } +#[cfg(feature = "tags")] +struct TagAccess<'a, R> { + de: &'a mut Deserializer, + typ: TagTyp, + len: &'a mut usize, +} + +#[derive(Debug, Clone, Copy)] +enum TagTyp { + Inline(u8), + U8, + U16, + U32, + U64, +} + +#[cfg(feature = "tags")] +impl<'de, 'a, R> de::SeqAccess<'de> for TagAccess<'a, R> +where + R: Read<'de>, +{ + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: de::DeserializeSeed<'de>, + { + if *self.len == 0 { + return Ok(None); + } + + if *self.len == 1 { + *self.len -= 1; + // actual value + let value = seed.deserialize(&mut *self.de)?; + return Ok(Some(value)); + } + + if *self.len == 2 { + *self.len -= 1; + // TAG + let mut td = TagDeserializer { + inner: &mut *self.de, + typ: self.typ, + }; + let tag = seed.deserialize(&mut td)?; + + return Ok(Some(tag)); + } + + unreachable!(); + } + + fn size_hint(&self) -> Option { + Some(*self.len) + } +} + +#[cfg(feature = "tags")] +struct TagDeserializer<'a, R> { + inner: &'a mut Deserializer, + typ: TagTyp, +} + +#[cfg(feature = "tags")] +impl<'de, 'a, R> MakeError for TagDeserializer<'a, R> +where + R: Read<'de>, +{ + fn error(&self, code: ErrorCode) -> Error { + self.inner.error(code) + } +} + +#[cfg(feature = "tags")] +impl<'de, 'a, R> de::Deserializer<'de> for &'a mut TagDeserializer<'a, R> +where + R: Read<'de>, +{ + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.deserialize_u64(visitor) + } + + fn deserialize_u8(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.typ { + TagTyp::U8 => self.deserialize_u64(visitor), + _ => Err(self.error(ErrorCode::InvalidUtf8)), + } + } + + fn deserialize_u16(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.typ { + TagTyp::U16 => self.deserialize_u64(visitor), + _ => Err(self.error(ErrorCode::InvalidUtf8)), + } + } + + fn deserialize_u32(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.typ { + TagTyp::U32 => self.deserialize_u64(visitor), + _ => Err(self.error(ErrorCode::InvalidUtf8)), + } + } + + fn deserialize_u64(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let tag = match self.typ { + TagTyp::Inline(val) => (val - 0xc0) as u64, + TagTyp::U8 => self.inner.parse_u8()? as u64, + TagTyp::U16 => self.inner.parse_u16()? as u64, + TagTyp::U32 => self.inner.parse_u32()? as u64, + TagTyp::U64 => self.inner.parse_u64()?, + }; + + visitor.visit_u64(tag) + } + + serde::forward_to_deserialize_any! { + bool i8 i16 i32 i64 i128 u128 f32 f64 char str string unit + unit_struct seq tuple tuple_struct map struct identifier ignored_any + bytes byte_buf option newtype_struct enum + } +} + +#[cfg(feature = "tags")] +impl<'de, 'a, R> MakeError for TagAccess<'a, R> +where + R: Read<'de>, +{ + fn error(&self, code: ErrorCode) -> Error { + self.de.error(code) + } +} + struct IndefiniteSeqAccess<'a, R> { de: &'a mut Deserializer, } diff --git a/src/lib.rs b/src/lib.rs index 749cdeb9..91efc91d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -267,10 +267,16 @@ extern crate std; #[cfg(feature = "alloc")] extern crate alloc; +#[cfg(feature = "tags")] +#[macro_use] +extern crate serde_derive; + pub mod de; pub mod error; mod read; pub mod ser; +#[cfg(feature = "tags")] +mod tag; mod write; #[cfg(feature = "std")] @@ -308,3 +314,6 @@ pub use crate::ser::to_writer; #[cfg(feature = "std")] #[doc(inline)] pub use crate::value::Value; + +#[cfg(feature = "tags")] +pub use tag::EncodeCborTag; diff --git a/src/ser.rs b/src/ser.rs index 13d6f07b..8abbd70c 100644 --- a/src/ser.rs +++ b/src/ser.rs @@ -52,6 +52,8 @@ pub struct Serializer { writer: W, packed: bool, enum_as_map: bool, + #[cfg(feature = "tags")] + tag: bool, } impl Serializer @@ -67,6 +69,8 @@ where writer, packed: false, enum_as_map: true, + #[cfg(feature = "tags")] + tag: false, } } @@ -125,9 +129,7 @@ where /// without further information. #[inline] pub fn self_describe(&mut self) -> Result<()> { - let mut buf = [6 << 5 | 25, 0, 0]; - BigEndian::write_u16(&mut buf[1..], 55799); - self.writer.write_all(&buf).map_err(|e| e.into()) + self.write_tag(55799) } /// Unwrap the `Writer` from the `Serializer`. @@ -180,6 +182,11 @@ where } } + #[inline] + fn write_tag(&mut self, tag: u64) -> Result<()> { + self.write_u64(6, tag) + } + #[inline] fn serialize_collection<'a>( &'a mut self, @@ -293,6 +300,16 @@ where self.write_u32(0, value) } + #[cfg(feature = "tags")] + #[inline] + fn serialize_u64(self, value: u64) -> Result<()> { + if self.tag { + self.write_tag(value) + } else { + self.write_u64(0, value) + } + } + #[cfg(not(feature = "tags"))] #[inline] fn serialize_u64(self, value: u64) -> Result<()> { self.write_u64(0, value) @@ -485,9 +502,28 @@ where Ok(()) } + #[cfg(feature = "tags")] + #[inline] + fn serialize_struct(self, name: &'static str, len: usize) -> Result> { + let tagged = if name == "EncodeCborTag" { + true + } else { + self.write_u64(5, len as u64)?; + false + }; + + Ok(StructSerializer { + ser: self, + idx: 0, + tagged, + }) + } + + #[cfg(not(feature = "tags"))] #[inline] fn serialize_struct(self, _name: &'static str, len: usize) -> Result> { self.write_u64(5, len as u64)?; + Ok(StructSerializer { ser: self, idx: 0 }) } @@ -581,12 +617,45 @@ where pub struct StructSerializer<'a, W> { ser: &'a mut Serializer, idx: u32, + #[cfg(feature = "tags")] + tagged: bool, } impl<'a, W> StructSerializer<'a, W> where W: Write, { + #[cfg(feature = "tags")] + #[inline] + fn serialize_field_inner(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: ?Sized + ser::Serialize, + { + if self.tagged && key == "__cbor_tag_ser_tag" { + assert_eq!(self.idx, 0); + // write the tag as + self.ser.tag = true; + value.serialize(&mut *self.ser)?; + self.ser.tag = false; + } else if self.tagged && key == "__cbor_tag_ser_data" { + assert_eq!(self.idx, 1); + // only write the data, without key + value.serialize(&mut *self.ser)?; + } else { + // regular struct + if self.ser.packed { + self.idx.serialize(&mut *self.ser)?; + } else { + key.serialize(&mut *self.ser)?; + } + value.serialize(&mut *self.ser)?; + } + + self.idx += 1; + Ok(()) + } + + #[cfg(not(feature = "tags"))] #[inline] fn serialize_field_inner(&mut self, key: &'static str, value: &T) -> Result<()> where diff --git a/src/tag.rs b/src/tag.rs new file mode 100644 index 00000000..cfb41552 --- /dev/null +++ b/src/tag.rs @@ -0,0 +1,40 @@ +use serde::ser::{Serialize, SerializeStruct, Serializer}; + +/// Wrapper struct to handle encoding Cbor semantic tags. +#[derive(Deserialize)] +pub struct EncodeCborTag { + __cbor_tag_ser_tag: u64, + __cbor_tag_ser_data: T, +} + +impl EncodeCborTag { + /// Constructs a new `EncodeCborTag`, to wrap your type in a tag. + pub fn new(tag: u64, value: T) -> Self { + EncodeCborTag { + __cbor_tag_ser_tag: tag, + __cbor_tag_ser_data: value, + } + } + + /// Returns the tag. + pub fn tag(&self) -> u64 { + self.__cbor_tag_ser_tag + } + + /// Returns the inner value, consuming the wrapper. + pub fn value(self) -> T { + self.__cbor_tag_ser_data + } +} + +impl Serialize for EncodeCborTag { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut state = serializer.serialize_struct("EncodeCborTag", 2)?; + state.serialize_field("__cbor_tag_ser_tag", &self.__cbor_tag_ser_tag)?; + state.serialize_field("__cbor_tag_ser_data", &self.__cbor_tag_ser_data)?; + state.end() + } +} diff --git a/tests/de.rs b/tests/de.rs index 7e3fd9cf..49845a72 100644 --- a/tests/de.rs +++ b/tests/de.rs @@ -220,6 +220,21 @@ mod std_tests { assert_eq!(value.unwrap(), Value::Float(100000.0)); } + #[cfg(feature = "tags")] + #[test] + fn test_self_describing() { + let value: error::Result = + de::from_slice(&[0xd9, 0xd9, 0xf7, 0x66, 0x66, 0x6f, 0x6f, 0x62, 0x61, 0x72]); + assert_eq!( + value.unwrap(), + Value::Array(vec![ + Value::Integer(55799), + Value::Text("foobar".to_owned()) + ]) + ); + } + + #[cfg(not(feature = "tags"))] #[test] fn test_self_describing() { let value: error::Result = diff --git a/tests/ser.rs b/tests/ser.rs index d374ce2f..277c2daa 100644 --- a/tests/ser.rs +++ b/tests/ser.rs @@ -251,4 +251,69 @@ mod std_tests { assert_eq!(vec, b"\xF9\x51\x50"); assert_eq!(from_slice::(&vec[..]).unwrap(), 42.5f32); } + + #[cfg(feature = "tags")] + #[test] + fn test_tags() { + use serde::de::{Deserialize, Deserializer}; + use serde::Serialize; + + #[derive(Debug, PartialEq)] + struct MyTaggedValue(Vec); + + impl MyTaggedValue { + fn tag() -> u64 { + 9 + } + } + + impl Serialize for MyTaggedValue { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serde_cbor::EncodeCborTag::new(Self::tag(), &self.0).serialize(serializer) + } + } + + impl<'de> Deserialize<'de> for MyTaggedValue { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let wrapper = serde_cbor::EncodeCborTag::deserialize(deserializer)?; + if wrapper.tag() != Self::tag() { + return Err(serde::de::Error::custom(format!( + "Invalid tag: {}, expected {}", + wrapper.tag(), + Self::tag() + ))); + } + Ok(MyTaggedValue(wrapper.value())) + } + } + + // Roundtrip with custom type + + let value = MyTaggedValue(vec![1, 2, 3]); + let bytes = b"\xC9\x83\x01\x02\x03"; + assert_eq!(&to_vec(&value).unwrap()[..], bytes); + let res: MyTaggedValue = serde_cbor::de::from_slice_with_scratch(bytes, &mut []).unwrap(); + println!("{:?}", res); + + assert_eq!(res, value); + + // Deserialize tag into tuple + let res: (u64, (usize, usize, usize)) = + serde_cbor::de::from_slice_with_scratch(bytes, &mut []).unwrap(); + println!("{:?}", res); + + assert_eq!(res, (9, (1, 2, 3))); + + // Deserialize into tuple multiple bytes for the tag + println!("----"); + let value: (u64, (usize, usize, usize)) = + serde_cbor::de::from_slice(&[0xd9, 0xd9, 0xf6, 0x83, 0x01, 0x02, 0x03]).unwrap(); + assert_eq!(value, (55798, (1, 2, 3))); + } } From 912757c2e018eb9414a2d26091efed5db0784ec1 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Wed, 19 Jun 2019 22:10:18 +0200 Subject: [PATCH 2/2] feat: handle tags when going through values --- src/value/mod.rs | 3 +++ src/value/ser.rs | 19 ++++++++++++++++++- tests/ser.rs | 7 +++++++ 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/value/mod.rs b/src/value/mod.rs index 3e83c461..040a522e 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -51,6 +51,8 @@ pub enum Value { /// to establish canonical order may be slow and therefore insertion /// and retrieval of values will be slow too. Map(BTreeMap), + /// Semantic Tag + Tag(u64), // The hidden variant allows the enum to be extended // with variants for tags and simple values. #[doc(hidden)] @@ -147,6 +149,7 @@ impl Value { Text(_) => 3, Array(_) => 4, Map(_) => 5, + Tag(_) => 6, __Hidden => unreachable!(), } } diff --git a/src/value/ser.rs b/src/value/ser.rs index e51dea89..930dd7ad 100644 --- a/src/value/ser.rs +++ b/src/value/ser.rs @@ -24,7 +24,24 @@ impl serde::Serialize for Value { Value::Bytes(ref v) => serializer.serialize_bytes(&v), Value::Text(ref v) => serializer.serialize_str(&v), Value::Array(ref v) => v.serialize(serializer), - Value::Map(ref v) => v.serialize(serializer), + Value::Tag(v) => serializer.serialize_u64(v), + Value::Map(ref v) => { + if v.len() == 2 { + use serde::ser::SerializeStruct; + + let tag = v.get(&Value::Text("__cbor_tag_ser_tag".to_string())); + let value = v.get(&Value::Text("__cbor_tag_ser_data".to_string())); + if tag.is_some() && value.is_some() { + if let Some(Value::Integer(tag)) = tag { + let mut s = serializer.serialize_struct("EncodeCborTag", 2)?; + s.serialize_field("__cbor_tag_ser_tag", &Value::Tag(*tag as u64))?; + s.serialize_field("__cbor_tag_ser_data", value.unwrap())?; + return s.end(); + } + } + } + v.serialize(serializer) + } Value::Float(v) => serializer.serialize_f64(v), Value::Bool(v) => serializer.serialize_bool(v), Value::Null => serializer.serialize_unit(), diff --git a/tests/ser.rs b/tests/ser.rs index 277c2daa..60a21d58 100644 --- a/tests/ser.rs +++ b/tests/ser.rs @@ -303,6 +303,13 @@ mod std_tests { assert_eq!(res, value); + // Serialize via `Value` + + let value = MyTaggedValue(vec![1, 2, 3]); + let bytes = b"\xC9\x83\x01\x02\x03"; + let encoded = to_vec(&serde_cbor::value::to_value(value).unwrap()).unwrap(); + assert_eq!(&encoded[..], bytes); + // Deserialize tag into tuple let res: (u64, (usize, usize, usize)) = serde_cbor::de::from_slice_with_scratch(bytes, &mut []).unwrap();