From 05eabb4526ecb257e60f85b3ae4739e80ea1e6b5 Mon Sep 17 00:00:00 2001 From: Volker Mische Date: Wed, 25 Sep 2019 15:18:52 +0200 Subject: [PATCH] Add tags support behind feature flag Tags are stored as Newtype Structs with the special name `_TagStruct`. The tags are represented as a struct tuple consisting of the tag identifier and a value. Example: let tag = Value::Tag(123, Box::new(Value::Text("some value".to_string()))); let tag_encoded = to_vec(&tag).unwrap(); This implementation is heavily based on the way support for extensions was added to the msgpack-rust library [1][2]. msgpack has the same problem as CBOR: such an extension/tagging type is not directly supported by Serde, hence a workaround is needed. For me it makes sense to use a similarly working workaround across different libraries. This commit also contains an example on how to use custom types as tags. It can be run via: cargo run --example tags --features tags [1]: https://github.com/3Hren/msgpack-rust/commit/a34ab8fcca9f1133ec1c4fc446a59b71acf5613e [2]: https://github.com/3Hren/msgpack-rust/pull/216 --- .travis.yml | 1 + Cargo.toml | 2 + examples/tags.rs | 125 +++++++++++ src/de.rs | 74 +++++-- src/lib.rs | 29 ++- src/ser.rs | 31 +++ src/tags.rs | 558 +++++++++++++++++++++++++++++++++++++++++++++++ src/value/de.rs | 33 +++ src/value/mod.rs | 5 + src/value/ser.rs | 7 + tests/de.rs | 12 + tests/tags.rs | 253 +++++++++++++++++++++ 12 files changed, 1111 insertions(+), 19 deletions(-) create mode 100644 examples/tags.rs create mode 100644 src/tags.rs create mode 100644 tests/tags.rs diff --git a/.travis.yml b/.travis.yml index fe6a40c8..b85e867f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -21,3 +21,4 @@ script: - [[ $TRAVIS_RUST_VERSION != "1.31.0" ]] && cargo build --no-default-features --features alloc - cargo build --features unsealed_read_write # The crate should still build when the unsealed_read_write feature is enabled. 
- cargo build --no-default-features --features unsealed_read_write # The crate should still build when the unsealed_read_write feature is enabled and std disabled. + - cargo test --features tags # Run tags tests diff --git a/Cargo.toml b/Cargo.toml index 18088d53..3d1108c7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ serde = { version = "1.0.14", default-features = false } [dev-dependencies] serde_derive = { version = "1.0.14", default-features = false } +serde_bytes = "0.11.2" [features] default = ["std"] @@ -31,3 +32,4 @@ default = ["std"] alloc = ["serde/alloc"] std = ["serde/std" ] unsealed_read_write = [] +tags = [] diff --git a/examples/tags.rs b/examples/tags.rs new file mode 100644 index 00000000..3d944729 --- /dev/null +++ b/examples/tags.rs @@ -0,0 +1,125 @@ +fn main() { + #[cfg(feature = "tags")] + tags_example::main(); + #[cfg(not(feature = "tags"))] + println!("Run this example with the `--features tags` flag."); +} + +#[cfg(feature = "tags")] +mod tags_example { + use serde::de::{self, Unexpected}; + use serde::ser; + use serde_derive::{Deserialize, Serialize}; + + use serde_bytes; + + use serde_cbor::value::Value; + use serde_cbor::{from_slice, to_vec}; + + use std::fmt; + + #[derive(Debug, PartialEq)] + struct Cid(Vec); + + impl ser::Serialize for Cid { + fn serialize(&self, s: S) -> Result + where + S: ser::Serializer, + { + let tag = 42u64; + let value = serde_bytes::ByteBuf::from(&self.0[..]); + s.serialize_newtype_struct(serde_cbor::CBOR_TAG_STRUCT_NAME, &(tag, value)) + } + } + + struct CidVisitor; + + impl<'de> de::Visitor<'de> for CidVisitor { + type Value = Cid; + + fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "a sequence of tag and value") + } + + fn visit_newtype_struct(self, deserializer: D) -> Result + where + D: de::Deserializer<'de>, + { + deserializer.deserialize_tuple(2, self) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + let tag: u64 = seq + 
.next_element()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + let value: Value = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + + match (tag, value) { + // Only return the value if tag and value type match + (42, Value::Bytes(bytes)) => Ok(Cid(bytes)), + _ => { + let error = format!("tag: {:?}", tag); + let unexpected = Unexpected::Other(&error); + Err(de::Error::invalid_value(unexpected, &self)) + } + } + } + } + + impl<'de> de::Deserialize<'de> for Cid { + fn deserialize(deserializer: D) -> Result + where + D: de::Deserializer<'de>, + { + let visitor = CidVisitor; + deserializer.deserialize_newtype_struct(serde_cbor::CBOR_TAG_STRUCT_NAME, visitor) + } + } + + #[derive(Debug, PartialEq, Deserialize, Serialize)] + struct MyStruct { + cid: Cid, + data: bool, + } + + pub fn main() { + // Serialize any CBOR tag you like, the tag identifier is an u64 and the value is any of + // the CBOR values available. + let tag = Value::Tag(123, Box::new(Value::Text("some value".to_string()))); + println!("Tag: {:?}", tag); + let tag_encoded = to_vec(&tag).unwrap(); + println!("Encoded tag: {:x?}", tag_encoded); + + // You can also have your own custom tags implemented, that don't even use the CBOR `Value` + // type. In this example we encode a vector of integers as byte string with tag 42. + let cid = Cid(vec![1, 2, 3]); + println!("CID: {:?}", cid); + let cid_encoded = to_vec(&cid).unwrap(); + println!("Encoded CID: {:x?}", cid_encoded); + + // You can either decode it again as your custom object... + let cid_decoded_as_cid: Cid = from_slice(&cid_encoded).unwrap(); + println!("Decoded CID as CID: {:?}", cid_decoded_as_cid); + // ...or as a generic CBOR Value, which will then transform it into a `Tag()`. 
+ let cid_decoded_as_value: Value = from_slice(&cid_encoded).unwrap(); + println!("Decoded CID as Value: {:?}", cid_decoded_as_value); + + // Your custom object also works if it is nested in a struct + let mystruct = MyStruct { cid, data: true }; + println!("Custom struct: {:?}", mystruct); + let mystruct_encoded = to_vec(&mystruct).unwrap(); + println!("Encoded custom struct: {:?}", mystruct_encoded); + let mystruct_decoded_as_mystruct: MyStruct = from_slice(&mystruct_encoded).unwrap(); + println!("Decoded custom struct: {:?}", mystruct_decoded_as_mystruct); + let mystruct_decoded_as_value: Value = from_slice(&mystruct_encoded).unwrap(); + println!( + "Decoded custom struct as Value: {:?}", + mystruct_decoded_as_value + ); + } +} diff --git a/src/de.rs b/src/de.rs index 69cb4564..f0bf2ad1 100644 --- a/src/de.rs +++ b/src/de.rs @@ -21,6 +21,24 @@ use crate::read::Offset; #[cfg(any(feature = "std", feature = "alloc"))] pub use crate::read::SliceRead; pub use crate::read::{MutSliceRead, Read, SliceReadFixed}; +#[cfg(feature = "tags")] +use crate::tags::TagDeserializer; + +/// CBOR tags can be stored with different bit widths +#[derive(Clone, Copy, Debug)] +pub enum TagType { + /// CBOR tags < 24 are stored inline with the tag identifier + Inline(u8), + /// 1 byte CBOR tag + U8, + /// 2 bytes CBOR tag + U16, + /// 4 bytes CBOR tag + U32, + /// 8 bytes CBOR tag + U64, +} + /// Decodes a value from CBOR data in a slice. /// /// # Examples @@ -558,7 +576,7 @@ where // Don't warn about the `unreachable!` in case // exhaustive integer pattern matching is enabled. 
#[allow(unreachable_patterns)] - fn parse_value(&mut self, visitor: V) -> Result + pub(super) fn parse_value(&mut self, visitor: V) -> Result where V: de::Visitor<'de>, { @@ -704,23 +722,11 @@ where 0xbf => self.parse_indefinite_map(visitor), // Major type 6: optional semantic tagging of other major types - 0xc0..=0xd7 => self.parse_value(visitor), - 0xd8 => { - self.parse_u8()?; - self.parse_value(visitor) - } - 0xd9 => { - self.parse_u16()?; - self.parse_value(visitor) - } - 0xda => { - self.parse_u32()?; - self.parse_value(visitor) - } - 0xdb => { - self.parse_u64()?; - self.parse_value(visitor) - } + val @ 0xc0..=0xd7 => self.parse_tag(TagType::Inline(val), visitor), + 0xd8 => self.parse_tag(TagType::U8, visitor), + 0xd9 => self.parse_tag(TagType::U16, visitor), + 0xda => self.parse_tag(TagType::U32, visitor), + 0xdb => self.parse_tag(TagType::U64, visitor), 0xdc..=0xdf => Err(self.error(ErrorCode::UnassignedCode)), // Major type 7: floating-point numbers and other simple data types that need no content @@ -748,6 +754,38 @@ where _ => unreachable!(), } } + + /// Return the parsed tag as u64 + pub(super) fn parse_tag_by_type(&mut self, tag_type: TagType) -> Result { + let tag = match tag_type { + TagType::U8 => self.parse_u8()? as u64, + TagType::U16 => self.parse_u16()? as u64, + TagType::U32 => self.parse_u32()? as u64, + TagType::U64 => self.parse_u64()? 
as u64, + TagType::Inline(tag) => (tag - 0xc0) as u64, + }; + Ok(tag) + } + + #[cfg(feature = "tags")] + fn parse_tag(&mut self, tag_type: TagType, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let tag_de = TagDeserializer::new(self, tag_type); + visitor.visit_newtype_struct(tag_de) + } + + #[cfg(not(feature = "tags"))] + fn parse_tag(&mut self, tag_type: TagType, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + // Skip the tag by parsing it without producing any output + let _tag = self.parse_tag_by_type(tag_type)?; + // And parse the value only + self.parse_value(visitor) + } } impl<'de, 'a, R> de::Deserializer<'de> for &'a mut Deserializer diff --git a/src/lib.rs b/src/lib.rs index 749cdeb9..012b6160 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -241,7 +241,8 @@ //! //! * [Tags] are ignored during deserialization and can't be emitted during //! serialization. This is because Serde has no concept of tagged -//! values. See: [#3] +//! values. See: [#3]. Support for tags can be enabled with the +//! `tags` feature flag. //! * Unknown [simple values] cause an `UnassignedCode` error. //! The simple values *False* and *True* are recognized and parsed as bool. //! *Null* and *Undefined* are both deserialized as *unit*. @@ -276,6 +277,9 @@ mod write; #[cfg(feature = "std")] pub mod value; +#[cfg(feature = "tags")] +pub mod tags; + // Re-export the [items recommended by serde](https://serde.rs/conventions.html). 
#[doc(inline)] pub use crate::de::{Deserializer, StreamDeserializer}; @@ -308,3 +312,26 @@ pub use crate::ser::to_writer; #[cfg(feature = "std")] #[doc(inline)] pub use crate::value::Value; + +/// Name of the Serde newtype struct used to represent CBOR tags +/// CBOR Tag: Tag(tag, value) +/// Serde data model: _TagStruct((tag, value)) +/// Example Serde impl for custom type: +/// +/// ``` +/// use serde_cbor::value::Value; +/// use serde_cbor::{from_slice, to_vec}; +/// use serde_derive::{Deserialize, Serialize}; +/// +/// #[derive(Debug, PartialEq, Serialize, Deserialize)] +/// #[serde(rename = "_TagStruct")] +/// struct Cid((u64, Value)); +/// +/// let tag = Cid((42, Value::Bytes(vec![1, 2, 3]))); +/// let tag_encoded = to_vec(&tag).unwrap(); +/// assert_eq!(tag_encoded, [0xd8, 0x2a, 0x43, 0x01, 0x02, 0x03]); +/// let tag_decoded = from_slice::(&tag_encoded).unwrap(); +/// assert_eq!(tag_decoded, Value::Tag(42, Box::new(Value::Bytes(vec![1, 2, 3])))); +/// ``` +#[cfg(feature = "tags")] +pub const CBOR_TAG_STRUCT_NAME: &'static str = "_TagStruct"; diff --git a/src/ser.rs b/src/ser.rs index 13d6f07b..a92a6645 100644 --- a/src/ser.rs +++ b/src/ser.rs @@ -14,6 +14,14 @@ use serde::ser::{self, Serialize}; #[cfg(feature = "std")] use std::io; +#[cfg(feature = "tags")] +use crate::tags::TagStructSerializer; +#[cfg(feature = "tags")] +use crate::CBOR_TAG_STRUCT_NAME; + +#[cfg(feature = "tags")] +const MAJOR_TYPE_TAG: u8 = 6; + /// Serializes a value to a vector. 
#[cfg(any(feature = "std", feature = "alloc"))] pub fn to_vec(value: &T) -> Result> @@ -180,6 +188,12 @@ where } } + #[cfg(feature = "tags")] + #[inline] + pub(super) fn write_tag(&mut self, tag: u64) -> Result<()> { + self.write_u64(MAJOR_TYPE_TAG, tag) + } + #[inline] fn serialize_collection<'a>( &'a mut self, @@ -399,6 +413,23 @@ where } } + #[cfg(feature = "tags")] + #[inline] + fn serialize_newtype_struct(self, name: &'static str, value: &T) -> Result<()> + where + T: ?Sized + ser::Serialize, + { + if name == CBOR_TAG_STRUCT_NAME { + // The value is a struct that gets forwarded to the `TagStructSerializer` + let mut tag_ser = TagStructSerializer::new(self); + value.serialize(&mut tag_ser)?; + return Ok(()); + } + + value.serialize(self) + } + + #[cfg(not(feature = "tags"))] #[inline] fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result<()> where diff --git a/src/tags.rs b/src/tags.rs new file mode 100644 index 00000000..ec2c546f --- /dev/null +++ b/src/tags.rs @@ -0,0 +1,558 @@ +//! Serialize and Deserialize CBOR tags. 
+use crate::de::{Deserializer, TagType}; +use crate::error::{Error, Result}; +use crate::read::Read; +use crate::ser::Serializer; +use crate::write::Write; + +use serde::de; +use serde::ser::{self, Serialize}; + +/// This is simply for serializing the tag itself (not the data) +#[derive(Debug)] +pub struct TagSerializer<'a, W> { + ser: &'a mut Serializer, +} + +impl<'a, W> TagSerializer<'a, W> +where + W: Write, +{ + fn new(ser: &'a mut Serializer) -> Self { + Self { ser } + } +} + +impl<'a, W> ser::Serializer for &mut TagSerializer<'a, W> +where + W: Write, +{ + type Ok = (); + type Error = Error; + + type SerializeSeq = ser::Impossible<(), Error>; + type SerializeTuple = ser::Impossible<(), Error>; + type SerializeTupleStruct = ser::Impossible<(), Error>; + type SerializeTupleVariant = ser::Impossible<(), Error>; + type SerializeMap = ser::Impossible<(), Error>; + type SerializeStruct = ser::Impossible<(), Error>; + type SerializeStructVariant = ser::Impossible<(), Error>; + + #[inline] + fn serialize_bytes(self, _value: &[u8]) -> Result<()> { + Err(Error::message("expected an u64, received bytes")) + } + + #[inline] + fn serialize_bool(self, _value: bool) -> Result<()> { + Err(Error::message("expected an u64, received bool")) + } + + #[inline] + fn serialize_i8(self, _value: i8) -> Result<()> { + Err(Error::message("expected an u64, received i8")) + } + + #[inline] + fn serialize_i16(self, _value: i16) -> Result<()> { + Err(Error::message("expected an u64, received i16")) + } + + #[inline] + fn serialize_i32(self, _value: i32) -> Result<()> { + Err(Error::message("expected an u64, received i32")) + } + + #[inline] + fn serialize_i64(self, _value: i64) -> Result<()> { + Err(Error::message("expected an u64, received i64")) + } + + #[inline] + fn serialize_u8(self, _value: u8) -> Result<()> { + Err(Error::message("expected an u64, received u8")) + } + + #[inline] + fn serialize_u16(self, _value: u16) -> Result<()> { + Err(Error::message("expected an u64, received 
u16")) + } + + #[inline] + fn serialize_u32(self, _value: u32) -> Result<()> { + Err(Error::message("expected an u64, received u32")) + } + + // `Value::Tag` stores the tag identifier as a `u64`, so this is the only case that is accepted. + // `write_u64` makes sure that the actual value is written with the smallest representation possible + #[inline] + fn serialize_u64(self, value: u64) -> Result<()> { + self.ser.write_tag(value) + } + + #[inline] + fn serialize_f32(self, _value: f32) -> Result<()> { + Err(Error::message("expected an u64, received f32")) + } + + #[inline] + fn serialize_f64(self, _value: f64) -> Result<()> { + Err(Error::message("expected an u64, received f64")) + } + + #[inline] + fn serialize_char(self, _value: char) -> Result<()> { + Err(Error::message("expected an u64, received char")) + } + + #[inline] + fn serialize_str(self, _value: &str) -> Result<()> { + Err(Error::message("expected an u64, received str")) + } + + #[inline] + fn serialize_unit(self) -> Result<()> { + Err(Error::message("expected an u64, received unit")) + } + + #[inline] + fn serialize_unit_struct(self, _name: &'static str) -> Result<()> { + Err(Error::message("expected an u64, received unit_struct")) + } + + #[inline] + fn serialize_unit_variant( + self, + _name: &'static str, + _idx: u32, + _variant: &'static str, + ) -> Result<()> { + Err(Error::message("expected an u64, received unit_variant")) + } + + #[inline] + fn serialize_newtype_struct(self, _name: &'static str, _value: &T) -> Result<()> + where + T: Serialize, + { + Err(Error::message("expected an u64, received newtype_struct")) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _idx: u32, + _variant: &'static str, + _value: &T, + ) -> Result<()> + where + T: Serialize, + { + Err(Error::message("expected an u64, received newtype_variant")) + } + + #[inline] + fn serialize_none(self) -> Result<()> { + Err(Error::message("expected an u64, received none")) + } + + #[inline] + fn serialize_some(self, _value: &T) 
-> Result<()> + where + T: Serialize, + { + Err(Error::message("expected an u64, received some")) + } + + #[inline] + fn serialize_seq(self, _len: Option) -> Result { + Err(Error::message("expected an u64, received seq")) + } + + #[inline] + fn serialize_tuple(self, _len: usize) -> Result { + Err(Error::message("expected an u64, received tuple")) + } + + #[inline] + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Err(Error::message("expected an u64, received tuple_struct")) + } + + #[inline] + fn serialize_tuple_variant( + self, + _name: &'static str, + _idx: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(Error::message("expected an u64, received tuple_variant")) + } + + #[inline] + fn serialize_map(self, _len: Option) -> Result { + Err(Error::message("expected an u64, received map")) + } + + #[inline] + fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { + Err(Error::message("expected an u64, received struct")) + } + + #[inline] + fn serialize_struct_variant( + self, + _name: &'static str, + _idx: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(Error::message("expected an u64, received struct_variant")) + } +} + +/// Represents CBOR serialization implementation for tags +#[derive(Debug)] +pub struct TagStructSerializer<'a, W> { + // True if the tag (the first element of the tuple) was already read + tag_read: bool, + // The serializer for the actual tag + tag_tag_ser: TagSerializer<'a, W>, +} + +impl<'a, W> TagStructSerializer<'a, W> +where + W: Write, +{ + /// Creates a new serializer for CBOR tags. 
+ pub fn new(ser: &'a mut Serializer) -> Self { + Self { + tag_read: false, + tag_tag_ser: TagSerializer::new(ser), + } + } +} + +impl<'a, W> ser::Serializer for &mut TagStructSerializer<'a, W> +where + W: Write, +{ + type Ok = (); + type Error = Error; + + type SerializeSeq = serde::ser::Impossible<(), Error>; + type SerializeTuple = Self; + type SerializeTupleStruct = serde::ser::Impossible<(), Error>; + type SerializeTupleVariant = serde::ser::Impossible<(), Error>; + type SerializeMap = serde::ser::Impossible<(), Error>; + type SerializeStruct = serde::ser::Impossible<(), Error>; + type SerializeStructVariant = serde::ser::Impossible<(), Error>; + + #[inline] + fn serialize_bytes(self, _val: &[u8]) -> Result<()> { + Err(Error::message("expected tuple, received bytes")) + } + + #[inline] + fn serialize_bool(self, _val: bool) -> Result<()> { + Err(Error::message("expected tuple, received bool")) + } + + #[inline] + fn serialize_i8(self, _value: i8) -> Result<()> { + Err(Error::message("expected tuple, received i8")) + } + + #[inline] + fn serialize_i16(self, _val: i16) -> Result<()> { + Err(Error::message("expected tuple, received i16")) + } + + #[inline] + fn serialize_i32(self, _val: i32) -> Result<()> { + Err(Error::message("expected tuple, received i32")) + } + + #[inline] + fn serialize_i64(self, _val: i64) -> Result<()> { + Err(Error::message("expected tuple, received i64")) + } + + #[inline] + fn serialize_u8(self, _val: u8) -> Result<()> { + Err(Error::message("expected tuple, received u8")) + } + + #[inline] + fn serialize_u16(self, _val: u16) -> Result<()> { + Err(Error::message("expected tuple, received u16")) + } + + #[inline] + fn serialize_u32(self, _val: u32) -> Result<()> { + Err(Error::message("expected tuple, received u32")) + } + + #[inline] + fn serialize_u64(self, _val: u64) -> Result<()> { + Err(Error::message("expected tuple, received u64")) + } + + #[inline] + fn serialize_f32(self, _val: f32) -> Result<()> { + Err(Error::message("expected 
tuple, received f32")) + } + + #[inline] + fn serialize_f64(self, _val: f64) -> Result<()> { + Err(Error::message("expected tuple, received f64")) + } + + #[inline] + fn serialize_char(self, _val: char) -> Result<()> { + Err(Error::message("expected tuple, received char")) + } + + #[inline] + fn serialize_str(self, _val: &str) -> Result<()> { + Err(Error::message("expected tuple, received str")) + } + + #[inline] + fn serialize_unit(self) -> Result<()> { + Err(Error::message("expected tuple, received unit")) + } + + #[inline] + fn serialize_unit_struct(self, _name: &'static str) -> Result<()> { + Err(Error::message("expected tuple, received unit_struct")) + } + + #[inline] + fn serialize_unit_variant( + self, + _name: &'static str, + _idx: u32, + _variant: &'static str, + ) -> Result<()> { + Err(Error::message("expected tuple, received unit_variant")) + } + + #[inline] + fn serialize_newtype_struct(self, _name: &'static str, _value: &T) -> Result<()> + where + T: Serialize, + { + Err(Error::message("expected tuple, received newtype_struct")) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _idx: u32, + _variant: &'static str, + _value: &T, + ) -> Result<()> + where + T: Serialize, + { + Err(Error::message("expected tuple, received newtype_variant")) + } + + #[inline] + fn serialize_none(self) -> Result<()> { + Err(Error::message("expected tuple, received none")) + } + + #[inline] + fn serialize_some(self, _value: &T) -> Result<()> + where + T: Serialize, + { + Err(Error::message("expected tuple, received some")) + } + + #[inline] + fn serialize_seq(self, _len: Option) -> Result { + Err(Error::message("expected tuple, received seq")) + } + + #[inline] + fn serialize_tuple(self, len: usize) -> Result { + if len == 2 { + Ok(self) + } else { + Err(Error::message(format!( + "expected tuple with two elements, received tuple with {} elements", + len + ))) + } + } + + #[inline] + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, 
+ ) -> Result { + Err(Error::message("expected tuple, received tuple_struct")) + } + + #[inline] + fn serialize_tuple_variant( + self, + _name: &'static str, + _idx: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(Error::message("expected tuple, received tuple_variant")) + } + + #[inline] + fn serialize_map(self, _len: Option) -> Result { + Err(Error::message("expected tuple, received map")) + } + + #[inline] + fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { + Err(Error::message("expected tuple, received struct")) + } + + #[inline] + fn serialize_struct_variant( + self, + _name: &'static str, + _idx: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(Error::message("expected tuple, received struct_variant")) + } +} + +impl<'a, W> ser::SerializeTuple for &mut TagStructSerializer<'a, W> +where + W: Write, +{ + type Ok = (); + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> Result<()> { + // Serialize the value with the default serializer + if self.tag_read { + value.serialize(&mut *self.tag_tag_ser.ser) + } + // Serialize the tag itself + else { + self.tag_read = true; + value.serialize(&mut self.tag_tag_ser) + } + } + + fn end(self) -> Result<()> { + Ok(()) + } +} + +#[derive(Debug)] +enum TagDeserializerState { + New, + ReadTag, + ReadData, +} + +/// Deserialize a CBOR tag and its value +#[derive(Debug)] +pub struct TagDeserializer<'a, R> { + de: &'a mut Deserializer, + tag_type: TagType, + state: TagDeserializerState, +} + +impl<'de, 'a, R> TagDeserializer<'a, R> +where + R: Read<'de> + 'a, +{ + /// Creates a new TagDeserializer. 
+ pub fn new(de: &'a mut Deserializer, tag_type: TagType) -> Self { + TagDeserializer { + de, + tag_type, + state: TagDeserializerState::New, + } + } +} + +impl<'de, 'a, R> de::Deserializer<'de> for TagDeserializer<'a, R> +where + R: Read<'de> + 'a, +{ + type Error = Error; + + #[inline] + fn deserialize_any(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_seq(self) + } + + serde::forward_to_deserialize_any! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit option + seq bytes byte_buf map unit_struct newtype_struct + struct identifier tuple enum ignored_any tuple_struct + } +} + +/// Exposes a tag as a two-element sequence: the tag identifier followed by the tagged value. +impl<'de, 'a, R> de::SeqAccess<'de> for TagDeserializer<'a, R> +where + R: Read<'de> + 'a, +{ + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: de::DeserializeSeed<'de>, + { + // Both elements were already handed out. Signal the end of the sequence, so that visitors + // which read until `None` (e.g. `serde::de::IgnoredAny`) terminate instead of running + // into the `unreachable!()` of the element deserializer. + if let TagDeserializerState::ReadData = self.state { + return Ok(None); + } + Ok(Some(seed.deserialize(self)?)) + } +} + +/// Deserializer for Tag SeqAccess +impl<'de, 'a, R> de::Deserializer<'de> for &mut TagDeserializer<'a, R> +where + R: Read<'de> + 'a, +{ + type Error = Error; + + #[inline] + fn deserialize_any(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.state { + TagDeserializerState::New => { + let tag = self.de.parse_tag_by_type(self.tag_type)?; + self.state = TagDeserializerState::ReadTag; + visitor.visit_u64(tag) + } + TagDeserializerState::ReadTag => { + self.state = TagDeserializerState::ReadData; + self.de.parse_value(visitor) + } + TagDeserializerState::ReadData => unreachable!(), + } + } + + serde::forward_to_deserialize_any! 
{ + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit option + seq bytes byte_buf map unit_struct newtype_struct + tuple_struct struct identifier tuple enum ignored_any + } +} diff --git a/src/value/de.rs b/src/value/de.rs index 2905f3b7..6790da1e 100644 --- a/src/value/de.rs +++ b/src/value/de.rs @@ -134,6 +134,39 @@ impl<'de> de::Deserialize<'de> for Value { { Ok(Value::Float(v)) } + + #[cfg(feature = "tags")] + #[inline] + fn visit_newtype_struct(self, deserializer: D) -> Result + where + D: de::Deserializer<'de>, + { + struct TagValueVisitor; + impl<'de> serde::de::Visitor<'de> for TagValueVisitor { + type Value = Value; + + fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.write_str("any valid CBOR tag") + } + + #[inline] + fn visit_seq(self, mut seq: V) -> Result + where + V: de::SeqAccess<'de>, + { + let tag = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + let value = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + + Ok(Value::Tag(tag, Box::new(value))) + } + } + + deserializer.deserialize_tuple(2, TagValueVisitor) + } } deserializer.deserialize_any(ValueVisitor) diff --git a/src/value/mod.rs b/src/value/mod.rs index 3e83c461..bea83d0e 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -51,6 +51,9 @@ pub enum Value { /// to establish canonical order may be slow and therefore insertion /// and retrieval of values will be slow too. Map(BTreeMap), + /// CBOR Tags + #[cfg(feature = "tags")] + Tag(u64, Box), // The hidden variant allows the enum to be extended // with variants for tags and simple values. 
#[doc(hidden)] @@ -147,6 +150,8 @@ impl Value { Text(_) => 3, Array(_) => 4, Map(_) => 5, + #[cfg(feature = "tags")] + Tag(_, _) => 6, __Hidden => unreachable!(), } } diff --git a/src/value/ser.rs b/src/value/ser.rs index e51dea89..df53f566 100644 --- a/src/value/ser.rs +++ b/src/value/ser.rs @@ -13,6 +13,9 @@ use serde::{self, Serialize}; use crate::value::Value; +#[cfg(feature = "tags")] +use crate::CBOR_TAG_STRUCT_NAME; + impl serde::Serialize for Value { #[inline] fn serialize(&self, serializer: S) -> Result @@ -28,6 +31,10 @@ impl serde::Serialize for Value { Value::Float(v) => serializer.serialize_f64(v), Value::Bool(v) => serializer.serialize_bool(v), Value::Null => serializer.serialize_unit(), + #[cfg(feature = "tags")] + Value::Tag(ref tag, ref v) => { + serializer.serialize_newtype_struct(CBOR_TAG_STRUCT_NAME, &(tag, v)) + } Value::__Hidden => unreachable!(), } } diff --git a/tests/de.rs b/tests/de.rs index 7e3fd9cf..2712501b 100644 --- a/tests/de.rs +++ b/tests/de.rs @@ -220,6 +220,18 @@ mod std_tests { assert_eq!(value.unwrap(), Value::Float(100000.0)); } + #[cfg(feature = "tags")] + #[test] + fn test_self_describing() { + let value: error::Result = + de::from_slice(&[0xd9, 0xd9, 0xf7, 0x66, 0x66, 0x6f, 0x6f, 0x62, 0x61, 0x72]); + assert_eq!( + value.unwrap(), + Value::Tag(55799, Box::new(Value::Text("foobar".to_owned()))) + ); + } + + #[cfg(not(feature = "tags"))] #[test] fn test_self_describing() { let value: error::Result = diff --git a/tests/tags.rs b/tests/tags.rs new file mode 100644 index 00000000..1c15d5a8 --- /dev/null +++ b/tests/tags.rs @@ -0,0 +1,253 @@ +#[cfg(feature = "tags")] +mod tags_tests { + use serde_bytes; + use serde_cbor::value::Value; + use serde_cbor::{from_slice, to_vec}; + use serde_derive::{Deserialize, Serialize}; + + fn serialize_and_compare(value: T, expected: &[u8]) { + assert_eq!(to_vec(&value).unwrap(), expected); + } + + #[test] + fn test_tags_inline() { + let value = Value::Tag(1, Box::new(Value::Bool(true))); + 
serialize_and_compare(value, &[0xc1, 0xf5]); + } + + #[test] + fn test_tags_u8() { + let value = Value::Tag(50, Box::new(Value::Bool(true))); + serialize_and_compare(value, &[0xd8, 0x32, 0xf5]); + } + + #[test] + fn test_tags_u16() { + let value = Value::Tag(600, Box::new(Value::Bool(true))); + serialize_and_compare(value, &[0xd9, 0x02, 0x58, 0xf5]); + } + + #[test] + fn test_tags_u32() { + let value = Value::Tag(70_000, Box::new(Value::Bool(true))); + serialize_and_compare(value, &[0xda, 0x00, 0x01, 0x11, 0x70, 0xf5]); + } + + #[test] + fn test_tags_u64() { + let value = Value::Tag(8_000_000_000, Box::new(Value::Bool(true))); + serialize_and_compare( + value, + &[0xdb, 0x00, 0x00, 0x00, 0x01, 0xDC, 0xD6, 0x50, 0x00, 0xf5], + ); + } + + #[test] + fn test_tags_null() { + let value = Value::Tag(40, Box::new(Value::Null)); + serialize_and_compare(value, &[0xd8, 0x28, 0xf6]); + } + + #[test] + fn test_tags_bool() { + let value = Value::Tag(40, Box::new(Value::Bool(false))); + serialize_and_compare(value, &[0xd8, 0x28, 0xf4]); + } + + #[test] + fn test_tags_integer() { + let value = Value::Tag(40, Box::new(Value::Integer(12345))); + serialize_and_compare(value, &[0xd8, 0x28, 0x19, 0x30, 0x39]); + } + + #[test] + fn test_tags_float() { + let value = Value::Tag(40, Box::new(Value::Float(-5.5))); + serialize_and_compare(value, &[0xd8, 0x28, 0xF9, 0xC5, 0x80]); + } + + #[test] + fn test_tags_bytes() { + let value = Value::Tag(40, Box::new(Value::Bytes(vec![3, 4, 5]))); + serialize_and_compare(value, &[0xd8, 0x28, 0x43, 0x03, 0x04, 0x05]); + } + + #[test] + fn test_tags_text() { + let value = Value::Tag(40, Box::new(Value::Text("yay".to_string()))); + serialize_and_compare(value, &[0xd8, 0x28, 0x63, 0x79, 0x61, 0x79]); + } + + #[test] + fn test_tags_array() { + let value = Value::Tag( + 40, + Box::new(Value::Array(vec![Value::Bool(true), Value::Integer(7)])), + ); + serialize_and_compare(value, &[0xd8, 0x28, 0x82, 0xf5, 0x07]); + } + + #[test] + fn test_tags_map() { + let 
mut map = std::collections::BTreeMap::new(); + map.insert("foo", 1); + map.insert("bar", 2); + + let value = Value::Tag(40, Box::new(serde_cbor::value::to_value(map).unwrap())); + serialize_and_compare( + value, + &[ + 0xd8, 0x28, 0xa2, 0x63, 0x62, 0x61, 0x72, 0x02, 0x63, 0x66, 0x6f, 0x6f, 0x01, + ], + ); + } + + #[test] + fn test_tags_tag() { + let value = Value::Tag(40, Box::new(Value::Tag(54321, Box::new(Value::Null)))); + serialize_and_compare(value, &[0xd8, 0x28, 0xd9, 0xd4, 0x31, 0xf6]); + } + + #[test] + fn test_tags_derive_struct() { + #[derive(Debug, PartialEq, Serialize, Deserialize)] + #[serde(rename = "_TagStruct")] + struct MyType((u64, Value)); + + let value = MyType((42, Value::Bytes(vec![1, 2, 3]))); + serialize_and_compare(value, &[0xd8, 0x2a, 0x43, 0x01, 0x02, 0x03]); + } + + #[test] + fn test_tag_decode() { + let tag_encoded = [0xd8, 0x2a, 0x43, 0x01, 0x02, 0x03]; + let tag_decoded = serde_cbor::de::from_slice::(&tag_encoded).unwrap(); + assert_eq!( + tag_decoded, + Value::Tag(42, Box::new(Value::Bytes(vec![1, 2, 3]))) + ); + } + + #[test] + fn test_tags_roundtrip() { + let tag_value = Value::Tag(42, Box::new(Value::Bytes(vec![1, 2, 3]))); + let tag_encoded = to_vec(&tag_value).unwrap(); + assert_eq!(tag_encoded, [0xd8, 0x2a, 0x43, 0x01, 0x02, 0x03]); + + let tag_decoded = from_slice::(&tag_encoded).unwrap(); + assert_eq!(tag_decoded, tag_value); + } + + #[test] + fn test_tags_custom_type() { + #[derive(Debug, PartialEq)] + struct Cid(Vec); + + impl serde::Serialize for Cid { + fn serialize(&self, s: S) -> Result + where + S: serde::ser::Serializer, + { + let tag = 42u64; + let value = serde_bytes::ByteBuf::from(&self.0[..]); + s.serialize_newtype_struct(serde_cbor::CBOR_TAG_STRUCT_NAME, &(tag, value)) + } + } + + struct CidVisitor; + + impl<'de> serde::de::Visitor<'de> for CidVisitor { + type Value = Cid; + + fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(fmt, "a sequence of tag and value") + } + + fn 
visit_newtype_struct(self, deserializer: D) -> Result + where + D: serde::de::Deserializer<'de>, + { + deserializer.deserialize_tuple(2, self) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + // First element of the tuple is the tag + let tag: u64 = seq + .next_element()? + .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?; + // Second element of the tuple is the value + let value: Value = seq + .next_element()? + .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?; + + match (tag, value) { + // Only return the value if tag and value type match + (42, Value::Bytes(bytes)) => Ok(Cid(bytes)), + _ => { + let error = format!("tag: {:?}", tag); + let unexpected = serde::de::Unexpected::Other(&error); + Err(serde::de::Error::invalid_value(unexpected, &self)) + } + } + } + } + + impl<'de> serde::de::Deserialize<'de> for Cid { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let visitor = CidVisitor; + deserializer.deserialize_newtype_struct(serde_cbor::CBOR_TAG_STRUCT_NAME, visitor) + } + } + + #[derive(Debug, PartialEq, Deserialize, Serialize)] + struct MyStruct { + cid: Cid, + data: bool, + } + + // Tests with just the custom type + + let cid = Cid(vec![1, 2, 3]); + let cid_encoded = to_vec(&cid).unwrap(); + assert_eq!(cid_encoded, [0xd8, 0x2a, 0x43, 0x01, 0x02, 0x03]); + + let cid_decoded_as_cid: Cid = from_slice(&cid_encoded).unwrap(); + assert_eq!(cid_decoded_as_cid, cid); + + let cid_decoded_as_value: Value = from_slice(&cid_encoded).unwrap(); + assert_eq!( + cid_decoded_as_value, + Value::Tag(42, Box::new(Value::Bytes(vec![1, 2, 3]))) + ); + + // Tests with the Type nested in a struct + + let mystruct = MyStruct { cid, data: true }; + let mystruct_encoded = to_vec(&mystruct).unwrap(); + assert_eq!( + mystruct_encoded, + [ + 0xa2, 0x63, 0x63, 0x69, 0x64, 0xd8, 0x2a, 0x43, 0x1, 0x2, 0x3, 0x64, 0x64, 0x61, + 0x74, 0x61, 0xf5 + ] + ); + + let 
mystruct_decoded_as_mystruct: MyStruct = from_slice(&mystruct_encoded).unwrap(); + assert_eq!(mystruct_decoded_as_mystruct, mystruct); + + let mystruct_decoded_as_value: Value = from_slice(&mystruct_encoded).unwrap(); + let mut expected_map = std::collections::BTreeMap::new(); + expected_map.insert( + Value::Text("cid".to_string()), + Value::Tag(42, Box::new(Value::Bytes(vec![1, 2, 3]))), + ); + expected_map.insert(Value::Text("data".to_string()), Value::Bool(true)); + assert_eq!(mystruct_decoded_as_value, Value::Map(expected_map)); + } +}