From da543c347bdd2e97f283b0f9d20c584e94f8f905 Mon Sep 17 00:00:00 2001 From: Preston Evans Date: Fri, 8 Sep 2023 18:18:29 -0500 Subject: [PATCH 1/8] Initial impl --- src/{de.rs => de/from_bytes.rs} | 135 +++------ src/de/from_reader.rs | 475 ++++++++++++++++++++++++++++++++ src/de/mod.rs | 107 +++++++ 3 files changed, 613 insertions(+), 104 deletions(-) rename src/{de.rs => de/from_bytes.rs} (81%) create mode 100644 src/de/from_reader.rs create mode 100644 src/de/mod.rs diff --git a/src/de.rs b/src/de/from_bytes.rs similarity index 81% rename from src/de.rs rename to src/de/from_bytes.rs index 9240cc4..133ef07 100644 --- a/src/de.rs +++ b/src/de/from_bytes.rs @@ -1,9 +1,7 @@ -// Copyright (c) The Diem Core Contributors -// SPDX-License-Identifier: Apache-2.0 - use crate::error::{Error, Result}; use serde::de::{self, Deserialize, DeserializeSeed, IntoDeserializer, Visitor}; -use std::convert::TryFrom; + +use super::BcsDeserializer; /// Deserializes a `&[u8]` into a type. /// @@ -53,6 +51,29 @@ where deserializer.end().map(move |_| t) } +impl<'de> BcsDeserializer for Deserializer<'de> { + fn next(&mut self) -> Result { + let byte = self.peek()?; + self.input = &self.input[1..]; + Ok(byte) + } + + fn max_remaining_depth(&mut self) -> usize { + self.max_remaining_depth + } + + fn max_remaining_depth_mut(&mut self) -> &mut usize { + &mut self.max_remaining_depth + } + + fn fill_slice(&mut self, slice: &mut [u8]) -> Result<()> { + for byte in slice { + *byte = self.next()?; + } + Ok(()) + } +} + /// Deserialization implementation for BCS struct Deserializer<'de> { input: &'de [u8], @@ -86,111 +107,17 @@ impl<'de> Deserializer<'de> { self.input.first().copied().ok_or(Error::Eof) } - fn next(&mut self) -> Result { - let byte = self.peek()?; - self.input = &self.input[1..]; - Ok(byte) - } - - fn parse_bool(&mut self) -> Result { - let byte = self.next()?; - - match byte { - 0 => Ok(false), - 1 => Ok(true), - _ => Err(Error::ExpectedBoolean), - } - } - - fn fill_slice(&mut self, slice: &mut [u8]) -> Result<()> { - for byte in slice { - *byte = self.next()?; - } - Ok(()) - } - - fn parse_u8(&mut self) -> Result { - self.next() - } - - fn parse_u16(&mut self) -> Result { - let mut le_bytes = [0; 2]; - self.fill_slice(&mut le_bytes)?; - Ok(u16::from_le_bytes(le_bytes)) - } - - fn parse_u32(&mut self) -> Result { - let mut le_bytes = [0; 4]; - self.fill_slice(&mut le_bytes)?; - Ok(u32::from_le_bytes(le_bytes)) - } - - fn parse_u64(&mut self) -> Result { - let mut le_bytes = [0; 8]; - self.fill_slice(&mut le_bytes)?; - Ok(u64::from_le_bytes(le_bytes)) - } - - fn parse_u128(&mut self) -> Result { - let mut le_bytes = [0; 16]; - self.fill_slice(&mut le_bytes)?; - Ok(u128::from_le_bytes(le_bytes)) - } - - #[allow(clippy::integer_arithmetic)] - fn parse_u32_from_uleb128(&mut self) -> Result { - let mut value: u64 = 0; - for shift in (0..32).step_by(7) { - let byte = self.next()?; - let digit = byte & 0x7f; - value |= u64::from(digit) << shift; - // If the highest bit of `byte` is 0, return the final value. - if digit == byte { - if shift > 0 && digit == 0 { - // We only accept canonical ULEB128 encodings, therefore the - // heaviest (and last) base-128 digit must be non-zero. - return Err(Error::NonCanonicalUleb128Encoding); - } - // Decoded integer must not overflow. - return u32::try_from(value) - .map_err(|_| Error::IntegerOverflowDuringUleb128Decoding); - } - } - // Decoded integer must not overflow. - Err(Error::IntegerOverflowDuringUleb128Decoding) - } - - fn parse_length(&mut self) -> Result { - let len = self.parse_u32_from_uleb128()? as usize; - if len > crate::MAX_SEQUENCE_LENGTH { - return Err(Error::ExceededMaxLen(len)); - } - Ok(len) + fn parse_string_borrowed(&mut self) -> Result<&'de str> { + let slice = self.parse_bytes_borrowed()?; + std::str::from_utf8(slice).map_err(|_| Error::Utf8) } - fn parse_bytes(&mut self) -> Result<&'de [u8]> { + fn parse_bytes_borrowed(&mut self) -> Result<&'de [u8]> { let len = self.parse_length()?; let slice = self.input.get(..len).ok_or(Error::Eof)?; self.input = &self.input[len..]; Ok(slice) } - - fn parse_string(&mut self) -> Result<&'de str> { - let slice = self.parse_bytes()?; - std::str::from_utf8(slice).map_err(|_| Error::Utf8) - } - - fn enter_named_container(&mut self, name: &'static str) -> Result<()> { - if self.max_remaining_depth == 0 { - return Err(Error::ExceededContainerDepthLimit(name)); - } - self.max_remaining_depth -= 1; - Ok(()) - } - - fn leave_named_container(&mut self) { - self.max_remaining_depth += 1; - } } impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { @@ -306,7 +233,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - visitor.visit_borrowed_str(self.parse_string()?) + visitor.visit_borrowed_str(self.parse_string_borrowed()?) } fn deserialize_string(self, visitor: V) -> Result @@ -320,7 +247,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - visitor.visit_borrowed_bytes(self.parse_bytes()?) + visitor.visit_borrowed_bytes(self.parse_bytes_borrowed()?) } fn deserialize_byte_buf(self, visitor: V) -> Result diff --git a/src/de/from_reader.rs b/src/de/from_reader.rs new file mode 100644 index 0000000..e550cbc --- /dev/null +++ b/src/de/from_reader.rs @@ -0,0 +1,475 @@ +use crate::error::{Error, Result}; +use serde::de::{self, DeserializeSeed, IntoDeserializer, Visitor}; +use std::io::Read; + +use super::BcsDeserializer; + +/// Deserializes BCS from a [`std::io::Read`]er +pub struct DeserializeReader<'de, R> { + reader: TeeReader<'de, R>, + max_remaining_depth: usize, +} + +impl<'de, R: Read> DeserializeReader<'de, R> { + /// Wraps the provided reader in a new [`DeserializeReader`] + fn new(reader: &'de mut R, max_remaining_depth: usize) -> Self { + DeserializeReader { + reader: TeeReader::new(reader), + max_remaining_depth, + } + } +} + +impl<'de, R: Read> BcsDeserializer for DeserializeReader<'de, R> { + fn fill_slice(&mut self, slice: &mut [u8]) -> Result<()> { + Ok(self.reader.read_exact(&mut slice[..])?) + } + + fn max_remaining_depth(&mut self) -> usize { + self.max_remaining_depth + } + + fn max_remaining_depth_mut(&mut self) -> &mut usize { + &mut self.max_remaining_depth + } +} + +impl<'de, R: Read> DeserializeReader<'de, R> { + /// Parse a vector of bytes from the reader + fn parse_vec(&mut self) -> Result> { + let len = self.parse_length()?; + let mut output = vec![0; len]; + self.fill_slice(&mut output)?; + Ok(output) + } + + /// Parse a String from the reader + fn parse_string(&mut self) -> Result { + let bytes = self.parse_vec()?; + String::from_utf8(bytes).map_err(|_| Error::Utf8) + } +} + +impl<'de, 'a, R: Read> de::Deserializer<'de> for &'a mut DeserializeReader<'de, R> { + type Error = Error; + + // BCS is not a self-describing format so we can't implement `deserialize_any` + fn deserialize_any(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(Error::NotSupported("deserialize_any")) + } + + fn deserialize_bool(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_bool(self.parse_bool()?) + } + + fn deserialize_i8(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_i8(self.parse_u8()? as i8) + } + + fn deserialize_i16(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_i16(self.parse_u16()? as i16) + } + + fn deserialize_i32(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_i32(self.parse_u32()? as i32) + } + + fn deserialize_i64(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_i64(self.parse_u64()? as i64) + } + + fn deserialize_i128(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_i128(self.parse_u128()? as i128) + } + + fn deserialize_u8(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_u8(self.parse_u8()?) + } + + fn deserialize_u16(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_u16(self.parse_u16()?) + } + + fn deserialize_u32(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_u32(self.parse_u32()?) + } + + fn deserialize_u64(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_u64(self.parse_u64()?) + } + + fn deserialize_u128(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_u128(self.parse_u128()?) + } + + fn deserialize_f32(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(Error::NotSupported("deserialize_f32")) + } + + fn deserialize_f64(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(Error::NotSupported("deserialize_f64")) + } + + fn deserialize_char(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(Error::NotSupported("deserialize_char")) + } + + fn deserialize_str(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_string(self.parse_string()?) + } + + fn deserialize_string(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_str(visitor) + } + + fn deserialize_bytes(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_byte_buf(self.parse_vec()?) + } + + fn deserialize_byte_buf(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_bytes(visitor) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let byte = self.next()?; + + match byte { + 0 => visitor.visit_none(), + 1 => visitor.visit_some(self), + _ => Err(Error::ExpectedOption), + } + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_unit_struct(self, name: &'static str, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.enter_named_container(name)?; + let r = self.deserialize_unit(visitor); + self.leave_named_container(); + r + } + + fn deserialize_newtype_struct(self, name: &'static str, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.enter_named_container(name)?; + let r = visitor.visit_newtype_struct(&mut *self); + self.leave_named_container(); + r + } + #[allow(clippy::needless_borrow)] + fn deserialize_seq(mut self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let len = self.parse_length()?; + visitor.visit_seq(SeqDeserializer::new(&mut self, len)) + } + #[allow(clippy::needless_borrow)] + fn deserialize_tuple(mut self, len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq(SeqDeserializer::new(&mut self, len)) + } + #[allow(clippy::needless_borrow)] + fn deserialize_tuple_struct( + mut self, + name: &'static str, + len: usize, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.enter_named_container(name)?; + let r = visitor.visit_seq(SeqDeserializer::new(&mut self, len)); + self.leave_named_container(); + r + } + #[allow(clippy::needless_borrow)] + fn deserialize_map(mut self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let len = self.parse_length()?; + visitor.visit_map(MapDeserializer::new(&mut self, len)) + } + #[allow(clippy::needless_borrow)] + fn deserialize_struct( + mut self, + name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.enter_named_container(name)?; + let r = visitor.visit_seq(SeqDeserializer::new(&mut self, fields.len())); + self.leave_named_container(); + r + } + + fn deserialize_enum( + self, + name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.enter_named_container(name)?; + let r = visitor.visit_enum(&mut *self); + self.leave_named_container(); + r + } + + // BCS does not utilize identifiers, so throw them away + fn deserialize_identifier(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_bytes(_visitor) + } + + // BCS is not a self-describing format so we can't implement `deserialize_ignored_any` + fn deserialize_ignored_any(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(Error::NotSupported("deserialize_ignored_any")) + } + + // BCS is not a human readable format + fn is_human_readable(&self) -> bool { + false + } +} + +struct SeqDeserializer<'a, 'de: 'a, R> { + de: &'a mut DeserializeReader<'de, R>, + remaining: usize, +} +#[allow(clippy::needless_borrow)] +impl<'a, 'de: 'a, R> SeqDeserializer<'a, 'de, R> { + fn new(de: &'a mut DeserializeReader<'de, R>, remaining: usize) -> Self { + Self { de, remaining } + } +} + +impl<'de, 'a, R: Read> de::SeqAccess<'de> for SeqDeserializer<'a, 'de, R> { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: DeserializeSeed<'de>, + { + if self.remaining == 0 { + Ok(None) + } else { + self.remaining -= 1; + seed.deserialize(&mut *self.de).map(Some) + } + } + + fn size_hint(&self) -> Option { + Some(self.remaining) + } +} + +/// A reader that can optionally capture all bytes from an underlying [`Read`]er +pub struct TeeReader<'a, R> { + reader: &'a mut R, + capture_buffer: Option>, +} + +impl<'a, R> TeeReader<'a, R> { + /// Wrapse the provided reader in a new [`TeeReader`]. + pub fn new(reader: &'a mut R) -> Self { + Self { + reader, + capture_buffer: Default::default(), + } + } +} + +impl<'a, R: Read> Read for TeeReader<'a, R> { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let bytes_read = self.reader.read(buf)?; + if let Some(ref mut buffer) = self.capture_buffer { + buffer.extend_from_slice(&buf[..bytes_read]); + } + Ok(bytes_read) + } +} + +struct MapDeserializer<'a, 'de: 'a, R> { + de: &'a mut DeserializeReader<'de, R>, + remaining: usize, + previous_key_bytes: Option>, +} + +impl<'a, 'de, R: Read> MapDeserializer<'a, 'de, R> { + fn new(de: &'a mut DeserializeReader<'de, R>, remaining: usize) -> Self { + Self { + de, + remaining, + previous_key_bytes: None, + } + } +} + +impl<'de, 'a, R: Read> de::MapAccess<'de> for MapDeserializer<'a, 'de, R> +where + 'de: 'a, +{ + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: DeserializeSeed<'de>, + { + match self.remaining.checked_sub(1) { + None => Ok(None), + Some(remaining) => { + self.de.reader.capture_buffer = Some(Vec::new()); + let key_value = seed.deserialize(&mut *self.de)?; + let key_bytes = self.de.reader.capture_buffer.take().unwrap(); + + if let Some(ref previous_key_bytes) = self.previous_key_bytes { + if previous_key_bytes.as_slice() >= key_bytes.as_slice() { + return Err(Error::NonCanonicalMap); + } + } + self.remaining = remaining; + self.previous_key_bytes = Some(key_bytes); + Ok(Some(key_value)) + } + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'de>, + { + seed.deserialize(&mut *self.de) + } + + fn size_hint(&self) -> Option { + Some(self.remaining) + } +} + +impl<'a, 'de: 'a, R: Read> de::EnumAccess<'de> for &'a mut DeserializeReader<'de, R> { + type Error = Error; + type Variant = Self; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant)> + where + V: DeserializeSeed<'de>, + { + let variant_index = self.parse_u32_from_uleb128()?; + let result: Result = seed.deserialize(variant_index.into_deserializer()); + Ok((result?, self)) + } +} + +impl<'a, 'de: 'a, R: Read> de::VariantAccess<'de> for &'a mut DeserializeReader<'de, R> { + type Error = Error; + + fn unit_variant(self) -> Result<()> { + Ok(()) + } + + fn newtype_variant_seed(self, seed: T) -> Result + where + T: DeserializeSeed<'de>, + { + seed.deserialize(self) + } + + fn tuple_variant(self, len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + de::Deserializer::deserialize_tuple(self, len, visitor) + } + + fn struct_variant(self, fields: &'static [&'static str], visitor: V) -> Result + where + V: Visitor<'de>, + { + de::Deserializer::deserialize_tuple(self, fields.len(), visitor) + } +} diff --git a/src/de/mod.rs b/src/de/mod.rs new file mode 100644 index 0000000..554fa9f --- /dev/null +++ b/src/de/mod.rs @@ -0,0 +1,107 @@ +// Copyright (c) The Diem Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::error::{Error, Result}; + +use std::{ + convert::TryFrom, + ops::{AddAssign, SubAssign}, +}; + +mod from_bytes; +pub use from_bytes::*; +mod from_reader; + +trait BcsDeserializer { + fn max_remaining_depth(&mut self) -> usize; + fn max_remaining_depth_mut(&mut self) -> &mut usize; + + fn parse_bool(&mut self) -> Result { + let byte = self.next()?; + + match byte { + 0 => Ok(false), + 1 => Ok(true), + _ => Err(Error::ExpectedBoolean), + } + } + + fn fill_slice(&mut self, slice: &mut [u8]) -> Result<()>; + + fn next(&mut self) -> Result { + let mut byte = [0u8; 1]; + self.fill_slice(&mut byte)?; + Ok(byte[0]) + } + + fn parse_u8(&mut self) -> Result { + self.next() + } + + fn parse_u16(&mut self) -> Result { + let mut le_bytes = [0; 2]; + self.fill_slice(&mut le_bytes)?; + Ok(u16::from_le_bytes(le_bytes)) + } + + fn parse_u32(&mut self) -> Result { + let mut le_bytes = [0; 4]; + self.fill_slice(&mut le_bytes)?; + Ok(u32::from_le_bytes(le_bytes)) + } + + fn parse_u64(&mut self) -> Result { + let mut le_bytes = [0; 8]; + self.fill_slice(&mut le_bytes)?; + Ok(u64::from_le_bytes(le_bytes)) + } + + fn parse_u128(&mut self) -> Result { + let mut le_bytes = [0; 16]; + self.fill_slice(&mut le_bytes)?; + Ok(u128::from_le_bytes(le_bytes)) + } + + #[allow(clippy::integer_arithmetic)] + fn parse_u32_from_uleb128(&mut self) -> Result { + let mut value: u64 = 0; + for shift in (0..32).step_by(7) { + let byte = self.next()?; + let digit = byte & 0x7f; + value |= u64::from(digit) << shift; + // If the highest bit of `byte` is 0, return the final value. + if digit == byte { + if shift > 0 && digit == 0 { + // We only accept canonical ULEB128 encodings, therefore the + // heaviest (and last) base-128 digit must be non-zero. + return Err(Error::NonCanonicalUleb128Encoding); + } + // Decoded integer must not overflow. + return u32::try_from(value) + .map_err(|_| Error::IntegerOverflowDuringUleb128Decoding); + } + } + // Decoded integer must not overflow. + Err(Error::IntegerOverflowDuringUleb128Decoding) + } + + fn parse_length(&mut self) -> Result { + let len = self.parse_u32_from_uleb128()? as usize; + if len > crate::MAX_SEQUENCE_LENGTH { + return Err(Error::ExceededMaxLen(len)); + } + Ok(len) + } + + fn enter_named_container(&mut self, name: &'static str) -> Result<()> { + if self.max_remaining_depth() == 0 { + return Err(Error::ExceededContainerDepthLimit(name)); + } + self.max_remaining_depth_mut().sub_assign(1); + Ok(()) + } + + fn leave_named_container(&mut self) { + self.max_remaining_depth_mut().add_assign(1); + } +} From 044c2a20d0f5dc16caaea2ce64a6ba9c68480992 Mon Sep 17 00:00:00 2001 From: Preston Evans Date: Sat, 9 Sep 2023 12:11:57 -0500 Subject: [PATCH 2/8] Refactor; use generics --- src/de/mod.rs | 579 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 564 insertions(+), 15 deletions(-) diff --git a/src/de/mod.rs b/src/de/mod.rs index 554fa9f..502a401 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -2,19 +2,91 @@ // SPDX-License-Identifier: Apache-2.0 use crate::error::{Error, Result}; +use serde::de::{self, Deserialize, DeserializeSeed, IntoDeserializer, Visitor}; +use std::convert::TryFrom; -use std::{ - convert::TryFrom, - ops::{AddAssign, SubAssign}, -}; +/// Deserializes a `&[u8]` into a type. +/// +/// This function will attempt to interpret `bytes` as the BCS serialized form of `T` and +/// deserialize `T` from `bytes`. +/// +/// # Examples +/// +/// ``` +/// use bcs::from_bytes; +/// use serde::Deserialize; +/// +/// #[derive(Deserialize)] +/// struct Ip([u8; 4]); +/// +/// #[derive(Deserialize)] +/// struct Port(u16); +/// +/// #[derive(Deserialize)] +/// struct SocketAddr { +/// ip: Ip, +/// port: Port, +/// } +/// +/// let bytes = vec![0x7f, 0x00, 0x00, 0x01, 0x41, 0x1f]; +/// let socket_addr: SocketAddr = from_bytes(&bytes).unwrap(); +/// +/// assert_eq!(socket_addr.ip.0, [127, 0, 0, 1]); +/// assert_eq!(socket_addr.port.0, 8001); +/// ``` +pub fn from_bytes<'a, T>(bytes: &'a [u8]) -> Result +where + T: Deserialize<'a>, +{ + let mut deserializer = Deserializer::new(bytes, crate::MAX_CONTAINER_DEPTH); + let t = T::deserialize(&mut deserializer)?; + deserializer.end().map(move |_| t) +} + +/// Perform a stateful deserialization from a `&[u8]` using the provided `seed`. +pub fn from_bytes_seed<'a, T>(seed: T, bytes: &'a [u8]) -> Result +where + T: DeserializeSeed<'a>, +{ + let mut deserializer = Deserializer::new(bytes, crate::MAX_CONTAINER_DEPTH); + let t = seed.deserialize(&mut deserializer)?; + deserializer.end().map(move |_| t) +} + +/// Deserialization implementation for BCS +struct Deserializer<'de, R: ?Sized> { + input: &'de R, + max_remaining_depth: usize, +} -mod from_bytes; -pub use from_bytes::*; -mod from_reader; +impl<'de, R: ?Sized> Deserializer<'de, R> { + /// Creates a new `Deserializer` which will be deserializing the provided + /// input. + fn new(input: &'de R, max_remaining_depth: usize) -> Self { + Deserializer { + input, + max_remaining_depth, + } + } +} -trait BcsDeserializer { - fn max_remaining_depth(&mut self) -> usize; - fn max_remaining_depth_mut(&mut self) -> &mut usize; +trait BcsDeserializer<'de> { + type MaybeBorrowedBytes: AsRef<[u8]>; + + fn fill_slice(&mut self, slice: &mut [u8]) -> Result<()>; + + fn parse_and_visit_str(&mut self, visitor: V) -> Result + where + V: Visitor<'de>; + + fn parse_and_visit_bytes(&mut self, visitor: V) -> Result + where + V: Visitor<'de>; + + fn next_key_seed>( + &mut self, + seed: K, + ) -> Result<(K::Value, Self::MaybeBorrowedBytes), Error>; fn parse_bool(&mut self) -> Result { let byte = self.next()?; @@ -26,8 +98,6 @@ trait BcsDeserializer { } } - fn fill_slice(&mut self, slice: &mut [u8]) -> Result<()>; - fn next(&mut self) -> Result { let mut byte = [0u8; 1]; self.fill_slice(&mut byte)?; @@ -92,16 +162,495 @@ trait BcsDeserializer { } Ok(len) } +} + +impl<'de> BcsDeserializer<'de> for Deserializer<'de, [u8]> { + type MaybeBorrowedBytes = &'de [u8]; + fn next(&mut self) -> Result { + let byte = self.peek()?; + self.input = &self.input[1..]; + Ok(byte) + } + + fn fill_slice(&mut self, slice: &mut [u8]) -> Result<()> { + for byte in slice { + *byte = self.next()?; + } + Ok(()) + } + + fn parse_and_visit_str(&mut self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_borrowed_str(self.parse_string()?) + } + + fn parse_and_visit_bytes(&mut self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_borrowed_bytes(self.parse_bytes()?) + } + + fn next_key_seed>( + &mut self, + seed: K, + ) -> Result<(K::Value, Self::MaybeBorrowedBytes), Error> { + let previous_input_slice = self.input; + let key_value = seed.deserialize(&mut *self)?; + let key_len = previous_input_slice.len().saturating_sub(self.input.len()); + let key_bytes = &previous_input_slice[..key_len]; + Ok((key_value, key_bytes)) + } +} + +impl<'de> Deserializer<'de, [u8]> { + fn peek(&mut self) -> Result { + self.input.first().copied().ok_or(Error::Eof) + } + + /// The `Deserializer::end` method should be called after a type has been + /// fully deserialized. This allows the `Deserializer` to validate that + /// the there are no more bytes remaining in the input stream. + fn end(&mut self) -> Result<()> { + if self.input.is_empty() { + Ok(()) + } else { + Err(Error::RemainingInput) + } + } + + fn parse_bytes(&mut self) -> Result<&'de [u8]> { + let len = self.parse_length()?; + let slice = self.input.get(..len).ok_or(Error::Eof)?; + self.input = &self.input[len..]; + Ok(slice) + } + fn parse_string(&mut self) -> Result<&'de str> { + let slice = self.parse_bytes()?; + std::str::from_utf8(slice).map_err(|_| Error::Utf8) + } +} + +impl<'de, R: ?Sized> Deserializer<'de, R> { fn enter_named_container(&mut self, name: &'static str) -> Result<()> { - if self.max_remaining_depth() == 0 { + if self.max_remaining_depth == 0 { return Err(Error::ExceededContainerDepthLimit(name)); } - self.max_remaining_depth_mut().sub_assign(1); + self.max_remaining_depth -= 1; Ok(()) } fn leave_named_container(&mut self) { - self.max_remaining_depth_mut().add_assign(1); + self.max_remaining_depth += 1; + } +} + +impl<'de, 'a, R: ?Sized> de::Deserializer<'de> for &'a mut Deserializer<'de, R> +where + Deserializer<'de, R>: BcsDeserializer<'de>, +{ + type Error = Error; + + // BCS is not a self-describing format so we can't implement `deserialize_any` + fn deserialize_any(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(Error::NotSupported("deserialize_any")) + } + + fn deserialize_bool(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_bool(self.parse_bool()?) + } + + fn deserialize_i8(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_i8(self.parse_u8()? as i8) + } + + fn deserialize_i16(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_i16(self.parse_u16()? as i16) + } + + fn deserialize_i32(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_i32(self.parse_u32()? as i32) + } + + fn deserialize_i64(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_i64(self.parse_u64()? as i64) + } + + fn deserialize_i128(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_i128(self.parse_u128()? as i128) + } + + fn deserialize_u8(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_u8(self.parse_u8()?) + } + + fn deserialize_u16(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_u16(self.parse_u16()?) + } + + fn deserialize_u32(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_u32(self.parse_u32()?) + } + + fn deserialize_u64(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_u64(self.parse_u64()?) + } + + fn deserialize_u128(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_u128(self.parse_u128()?) + } + + fn deserialize_f32(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(Error::NotSupported("deserialize_f32")) + } + + fn deserialize_f64(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(Error::NotSupported("deserialize_f64")) + } + + fn deserialize_char(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(Error::NotSupported("deserialize_char")) + } + + fn deserialize_str(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.parse_and_visit_str(visitor) + } + + fn deserialize_string(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.parse_and_visit_str(visitor) + } + + fn deserialize_bytes(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.parse_and_visit_bytes(visitor) + } + + fn deserialize_byte_buf(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.parse_and_visit_bytes(visitor) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let byte = self.next()?; + + match byte { + 0 => visitor.visit_none(), + 1 => visitor.visit_some(self), + _ => Err(Error::ExpectedOption), + } + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_unit_struct(self, name: &'static str, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.enter_named_container(name)?; + let r = self.deserialize_unit(visitor); + self.leave_named_container(); + r + } + + fn deserialize_newtype_struct(self, name: &'static str, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.enter_named_container(name)?; + let r = visitor.visit_newtype_struct(&mut *self); + self.leave_named_container(); + r + } + #[allow(clippy::needless_borrow)] + fn deserialize_seq(mut self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let len = self.parse_length()?; + visitor.visit_seq(SeqDeserializer::new(&mut self, len)) + } + #[allow(clippy::needless_borrow)] + fn deserialize_tuple(mut self, len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq(SeqDeserializer::new(&mut self, len)) + } + #[allow(clippy::needless_borrow)] + fn deserialize_tuple_struct( + mut self, + name: &'static str, + len: usize, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.enter_named_container(name)?; + let r = visitor.visit_seq(SeqDeserializer::new(&mut self, len)); + self.leave_named_container(); + r + } + #[allow(clippy::needless_borrow)] + fn deserialize_map(mut self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let len = self.parse_length()?; + visitor.visit_map(MapDeserializer::new(&mut self, len)) + } + #[allow(clippy::needless_borrow)] + fn deserialize_struct( + mut self, + name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.enter_named_container(name)?; + let r = visitor.visit_seq(SeqDeserializer::new(&mut self, fields.len())); + self.leave_named_container(); + r + } + + fn deserialize_enum( + self, + name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.enter_named_container(name)?; + let r = visitor.visit_enum(&mut *self); + self.leave_named_container(); + r + } + + // BCS does not utilize identifiers, so throw them away + fn deserialize_identifier(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_bytes(_visitor) + } + + // BCS is not a self-describing format so we can't implement `deserialize_ignored_any` + fn deserialize_ignored_any(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(Error::NotSupported("deserialize_ignored_any")) + } + + // BCS is not a human readable format + fn is_human_readable(&self) -> bool { + false + } +} + +struct SeqDeserializer<'a, 'de: 'a, R: ?Sized> { + de: &'a mut Deserializer<'de, R>, + remaining: usize, +} +#[allow(clippy::needless_borrow)] +impl<'a, 'de, R: ?Sized> SeqDeserializer<'a, 'de, R> { + fn new(de: &'a mut Deserializer<'de, R>, remaining: usize) -> Self { + Self { de, remaining } + } +} + +impl<'de, 'a, R: ?Sized> de::SeqAccess<'de> for SeqDeserializer<'a, 'de, R> +where + Deserializer<'de, R>: BcsDeserializer<'de>, +{ + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: DeserializeSeed<'de>, + { + if self.remaining == 0 { + Ok(None) + } else { + self.remaining -= 1; + seed.deserialize(&mut *self.de).map(Some) + } + } + + fn size_hint(&self) -> Option { + Some(self.remaining) + } +} + +struct MapDeserializer<'a, 'de: 'a, R: ?Sized, B> { + de: &'a mut Deserializer<'de, R>, + remaining: usize, + previous_key_bytes: Option, +} + +impl<'a, 'de, R: ?Sized, B> MapDeserializer<'a, 'de, R, B> { + fn new(de: &'a mut Deserializer<'de, R>, remaining: usize) -> Self { + Self { + de, + remaining, + previous_key_bytes: None, + } + } +} + +impl<'de, 'a, R: ?Sized, B: AsRef<[u8]>> de::MapAccess<'de> for MapDeserializer<'a, 'de, R, B> +where + Deserializer<'de, R>: BcsDeserializer<'de, MaybeBorrowedBytes = B>, +{ + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: DeserializeSeed<'de>, + { + match self.remaining.checked_sub(1) { + None => Ok(None), + Some(remaining) => { + let (key_value, key_bytes) = self.de.next_key_seed(seed)?; + if let Some(ref previous_key_bytes) = self.previous_key_bytes { + if previous_key_bytes.as_ref() >= key_bytes.as_ref() { + return Err(Error::NonCanonicalMap); + } + } + self.remaining = remaining; + self.previous_key_bytes = Some(key_bytes); + Ok(Some(key_value)) + } + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'de>, + { + seed.deserialize(&mut *self.de) + } + + fn size_hint(&self) -> Option { + Some(self.remaining) + } +} + +impl<'de, 'a, R: ?Sized> de::EnumAccess<'de> for &'a mut Deserializer<'de, R> +where + Deserializer<'de, R>: BcsDeserializer<'de>, +{ + type Error = Error; + type Variant = Self; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant)> + where + V: DeserializeSeed<'de>, + { + let variant_index = self.parse_u32_from_uleb128()?; + let result: Result = seed.deserialize(variant_index.into_deserializer()); + Ok((result?, self)) + } +} + +impl<'de, 'a, R: ?Sized> de::VariantAccess<'de> for &'a mut Deserializer<'de, R> +where + Deserializer<'de, R>: BcsDeserializer<'de>, +{ + type Error = Error; + + fn unit_variant(self) -> Result<()> { + Ok(()) + } + + fn newtype_variant_seed(self, seed: T) -> Result + where + T: DeserializeSeed<'de>, + { + seed.deserialize(self) + } + + fn tuple_variant(self, len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + de::Deserializer::deserialize_tuple(self, len, visitor) + } + + fn struct_variant(self, fields: &'static [&'static str], visitor: V) -> Result + where + V: Visitor<'de>, + { + de::Deserializer::deserialize_tuple(self, fields.len(), visitor) } } From 8987638a50e36f07643ff4483e8de82ab019e07d Mon Sep 17 00:00:00 2001 From: Preston Evans Date: Sat, 9 Sep 2023 12:39:36 -0500 Subject: [PATCH 3/8] Complete implementation --- src/de/mod.rs | 129 +++++++++++++++++++++++++++++++++++++++++++------- src/lib.rs | 2 +- 2 files changed, 112 insertions(+), 19 deletions(-) diff --git a/src/de/mod.rs b/src/de/mod.rs index 502a401..ad8de46 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -2,8 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 use crate::error::{Error, Result}; -use serde::de::{self, Deserialize, DeserializeSeed, IntoDeserializer, Visitor}; -use std::convert::TryFrom; +use serde::de::{self, Deserialize, DeserializeOwned, DeserializeSeed, IntoDeserializer, Visitor}; +use std::{convert::TryFrom, io::Read}; /// Deserializes a `&[u8]` into a type. /// @@ -53,23 +53,70 @@ where deserializer.end().map(move |_| t) } +/// Deserialize a type from an implementation of [`Read`]. +pub fn from_reader(mut reader: &mut impl Read) -> Result +where + T: DeserializeOwned, +{ + let mut deserializer = Deserializer::from_reader(&mut reader, crate::MAX_CONTAINER_DEPTH); + T::deserialize(&mut deserializer) +} + /// Deserialization implementation for BCS -struct Deserializer<'de, R: ?Sized> { - input: &'de R, +struct Deserializer<'de, R> { + input: R, max_remaining_depth: usize, + _phantom: std::marker::PhantomData<&'de ()>, } -impl<'de, R: ?Sized> Deserializer<'de, R> { +impl<'de, R: Read> Deserializer<'de, TeeReader<&'de mut R>> { + fn from_reader(input: &'de mut R, max_remaining_depth: usize) -> Self { + Deserializer { + input: TeeReader::new(input), + max_remaining_depth, + _phantom: std::marker::PhantomData, + } + } +} + +impl<'de> Deserializer<'de, &'de [u8]> { /// Creates a new `Deserializer` which will be deserializing the provided /// input. - fn new(input: &'de R, max_remaining_depth: usize) -> Self { + fn new(input: &'de [u8], max_remaining_depth: usize) -> Self { Deserializer { input, max_remaining_depth, + _phantom: std::marker::PhantomData, + } + } +} + +/// A reader that can optionally capture all bytes from an underlying [`Read`]er +struct TeeReader { + reader: R, + capture_buffer: Option>, +} + +impl TeeReader { + /// Wrapse the provided reader in a new [`TeeReader`]. + pub fn new(reader: R) -> Self { + Self { + reader, + capture_buffer: Default::default(), } } } +impl Read for TeeReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let bytes_read = self.reader.read(buf)?; + if let Some(ref mut buffer) = self.capture_buffer { + buffer.extend_from_slice(&buf[..bytes_read]); + } + Ok(bytes_read) + } +} + trait BcsDeserializer<'de> { type MaybeBorrowedBytes: AsRef<[u8]>; @@ -164,7 +211,53 @@ trait BcsDeserializer<'de> { } } -impl<'de> BcsDeserializer<'de> for Deserializer<'de, [u8]> { +impl<'de, R: Read> Deserializer<'de, TeeReader> { + fn parse_vec(&mut self) -> Result> { + let len = self.parse_length()?; + let mut output = vec![0; len]; + self.fill_slice(&mut output)?; + Ok(output) + } + + fn parse_string(&mut self) -> Result { + let vec = self.parse_vec()?; + String::from_utf8(vec).map_err(|_| Error::Utf8) + } +} + +impl<'de, R: Read> BcsDeserializer<'de> for Deserializer<'de, TeeReader> { + type MaybeBorrowedBytes = Vec; + + fn fill_slice(&mut self, slice: &mut [u8]) -> Result<()> { + Ok(self.input.read_exact(slice)?) + } + + fn parse_and_visit_str(&mut self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_string(self.parse_string()?) + } + + fn parse_and_visit_bytes(&mut self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_byte_buf(self.parse_vec()?) + } + + fn next_key_seed>( + &mut self, + seed: K, + ) -> Result<(K::Value, Self::MaybeBorrowedBytes), Error> { + self.input.capture_buffer = Some(Vec::new()); + let key_value = seed.deserialize(&mut *self)?; + let key_bytes = self.input.capture_buffer.take().unwrap(); + Ok((key_value, key_bytes)) + } +} + +impl<'de> BcsDeserializer<'de> for Deserializer<'de, &'de [u8]> { type MaybeBorrowedBytes = &'de [u8]; fn next(&mut self) -> Result { let byte = self.peek()?; @@ -205,7 +298,7 @@ impl<'de> BcsDeserializer<'de> for Deserializer<'de, [u8]> { } } -impl<'de> Deserializer<'de, [u8]> { +impl<'de> Deserializer<'de, &'de [u8]> { fn peek(&mut self) -> Result { self.input.first().copied().ok_or(Error::Eof) } @@ -234,7 +327,7 @@ impl<'de> Deserializer<'de, [u8]> { } } -impl<'de, R: ?Sized> Deserializer<'de, R> { +impl<'de, R> Deserializer<'de, R> { fn enter_named_container(&mut self, name: &'static str) -> Result<()> { if self.max_remaining_depth == 0 { return Err(Error::ExceededContainerDepthLimit(name)); @@ -248,7 +341,7 @@ impl<'de, R: ?Sized> Deserializer<'de, R> { } } -impl<'de, 'a, R: ?Sized> de::Deserializer<'de> for &'a mut Deserializer<'de, R> +impl<'de, 'a, R> de::Deserializer<'de> for &'a mut Deserializer<'de, R> where Deserializer<'de, R>: BcsDeserializer<'de>, { @@ -518,18 +611,18 @@ where } } -struct SeqDeserializer<'a, 'de: 'a, R: ?Sized> { +struct SeqDeserializer<'a, 'de: 'a, R> { de: &'a mut Deserializer<'de, R>, remaining: usize, } #[allow(clippy::needless_borrow)] -impl<'a, 'de, R: ?Sized> SeqDeserializer<'a, 'de, R> { +impl<'a, 'de, R> SeqDeserializer<'a, 'de, R> { fn new(de: &'a mut Deserializer<'de, R>, remaining: usize) -> Self { Self { de, remaining } } } -impl<'de, 'a, R: ?Sized> de::SeqAccess<'de> for SeqDeserializer<'a, 'de, R> +impl<'de, 'a, R> de::SeqAccess<'de> for SeqDeserializer<'a, 'de, R> where Deserializer<'de, R>: BcsDeserializer<'de>, { @@ -552,13 +645,13 @@ where } } -struct MapDeserializer<'a, 'de: 'a, R: ?Sized, B> { +struct MapDeserializer<'a, 'de: 'a, R, B> { de: &'a mut Deserializer<'de, R>, remaining: usize, previous_key_bytes: Option, } -impl<'a, 'de, R: ?Sized, B> MapDeserializer<'a, 'de, R, B> { +impl<'a, 'de, R, B> MapDeserializer<'a, 'de, R, B> { fn new(de: &'a mut Deserializer<'de, R>, remaining: usize) -> Self { Self { de, @@ -568,7 +661,7 @@ impl<'a, 'de, R: ?Sized, B> MapDeserializer<'a, 'de, R, B> { } } -impl<'de, 'a, R: ?Sized, B: AsRef<[u8]>> de::MapAccess<'de> for MapDeserializer<'a, 'de, R, B> +impl<'de, 'a, R, B: AsRef<[u8]>> de::MapAccess<'de> for MapDeserializer<'a, 'de, R, B> where Deserializer<'de, R>: BcsDeserializer<'de, MaybeBorrowedBytes = B>, { @@ -606,7 +699,7 @@ where } } -impl<'de, 'a, R: ?Sized> de::EnumAccess<'de> for &'a mut Deserializer<'de, R> +impl<'de, 'a, R> de::EnumAccess<'de> for &'a mut Deserializer<'de, R> where Deserializer<'de, R>: BcsDeserializer<'de>, { @@ -623,7 +716,7 @@ where } } -impl<'de, 'a, R: ?Sized> de::VariantAccess<'de> for &'a mut Deserializer<'de, R> +impl<'de, 'a, R> de::VariantAccess<'de> for &'a mut Deserializer<'de, R> where Deserializer<'de, R>: BcsDeserializer<'de>, { diff --git a/src/lib.rs b/src/lib.rs index c8ee6a4..3f4d452 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -314,6 +314,6 @@ pub const MAX_SEQUENCE_LENGTH: usize = (1 << 31) - 1; /// Maximal allowed depth of BCS data, counting only structs and enums. pub const MAX_CONTAINER_DEPTH: usize = 500; -pub use de::{from_bytes, from_bytes_seed}; +pub use de::{from_bytes, from_bytes_seed, from_reader}; pub use error::{Error, Result}; pub use ser::{is_human_readable, serialize_into, serialized_size, to_bytes}; From 5d352b05f0195a720a77bb697d008aebd290fab3 Mon Sep 17 00:00:00 2001 From: Preston Evans Date: Sat, 9 Sep 2023 13:08:13 -0500 Subject: [PATCH 4/8] test --- src/error.rs | 8 ++- src/test_helpers.rs | 4 ++ tests/serde.rs | 119 ++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 126 insertions(+), 5 deletions(-) diff --git a/src/error.rs b/src/error.rs index 3a913e4..dad27b9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 use serde::{de, ser}; -use std::fmt; +use std::{fmt, io::ErrorKind}; use thiserror::Error; pub type Result = std::result::Result; @@ -45,7 +45,11 @@ pub enum Error { impl From for Error { fn from(err: std::io::Error) -> Self { - Error::Io(err.to_string()) + if err.kind() == ErrorKind::UnexpectedEof { + Error::Eof + } else { + Error::Io(err.to_string()) + } } } diff --git a/src/test_helpers.rs b/src/test_helpers.rs index 4c7e51f..8f361e5 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -8,4 +8,8 @@ where let bytes = crate::to_bytes(&t).unwrap(); let s: T = crate::from_bytes(&bytes).unwrap(); assert_eq!(t, s); + + let mut reader = std::io::Cursor::new(bytes); + let s_from_reader = crate::from_reader(&mut reader).unwrap(); + assert_eq!(t, s_from_reader); } diff --git a/tests/serde.rs b/tests/serde.rs index 77053a9..0d49e0b 100644 --- a/tests/serde.rs +++ b/tests/serde.rs @@ -13,7 +13,17 @@ use proptest::prelude::*; use proptest_derive::Arbitrary; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use bcs::{from_bytes, serialized_size, to_bytes, Error, MAX_CONTAINER_DEPTH, MAX_SEQUENCE_LENGTH}; +use bcs::{ + from_bytes, from_reader, serialized_size, to_bytes, Error, MAX_CONTAINER_DEPTH, + MAX_SEQUENCE_LENGTH, +}; + +/// A helper function to attempt deserialization via reader +fn from_bytes_via_reader(bytes: &[u8]) -> Result { + let mut reader = std::io::Cursor::new(bytes); + let s_from_reader = from_reader(&mut reader)?; + Ok(s_from_reader) +} fn is_same(t: T) where @@ -23,6 +33,9 @@ where let s: T = from_bytes(&bytes).unwrap(); assert_eq!(t, s); assert_eq!(bytes.len(), serialized_size(&t).unwrap()); + + let s_from_reader = from_bytes_via_reader(&bytes).unwrap(); + assert_eq!(t, s_from_reader); } // TODO deriving `Arbitrary` is currently broken for enum types @@ -254,6 +267,10 @@ proptest! { fn invalid_utf8() { let invalid_utf8 = vec![1, 0xFF]; assert_eq!(from_bytes::(&invalid_utf8), Err(Error::Utf8)); + assert_eq!( + from_bytes_via_reader::(&invalid_utf8), + Err(Error::Utf8) + ); } #[test] @@ -266,6 +283,7 @@ fn uleb_encoding_and_variant() { let valid_variant = vec![1]; from_bytes::(&valid_variant).unwrap(); + from_bytes_via_reader::(&valid_variant).unwrap(); let invalid_variant = vec![5]; // Error comes from serde @@ -275,10 +293,20 @@ fn uleb_encoding_and_variant() { "invalid value: integer `5`, expected variant index 0 <= i < 2".into() )) ); + assert_eq!( + from_bytes_via_reader::(&invalid_variant), + Err(Error::Custom( + "invalid value: integer `5`, expected variant index 0 <= i < 2".into() + )) + ); let invalid_bytes = vec![0x80, 0x80, 0x80, 0x80]; // Error is due to EOF. assert_eq!(from_bytes::(&invalid_bytes), Err(Error::Eof)); + assert_eq!( + from_bytes_via_reader::(&invalid_bytes), + Err(Error::Eof) + ); let invalid_uleb = vec![0x80, 0x80, 0x80, 0x80, 0x80]; // Error comes from uleb decoder because u32 are never that long. @@ -286,6 +314,10 @@ fn uleb_encoding_and_variant() { from_bytes::(&invalid_uleb), Err(Error::IntegerOverflowDuringUleb128Decoding) ); + assert_eq!( + from_bytes_via_reader::(&invalid_uleb), + Err(Error::IntegerOverflowDuringUleb128Decoding) + ); let invalid_uleb = vec![0x80, 0x80, 0x80, 0x80, 0x1f]; // Error comes from uleb decoder because we are truncating a larger integer into u32. @@ -293,6 +325,10 @@ fn uleb_encoding_and_variant() { from_bytes::(&invalid_uleb), Err(Error::IntegerOverflowDuringUleb128Decoding) ); + assert_eq!( + from_bytes_via_reader::(&invalid_uleb), + Err(Error::IntegerOverflowDuringUleb128Decoding) + ); let invalid_uleb = vec![0x80, 0x80, 0x80, 0x80, 0x0f]; // Error comes from Serde because ULEB integer is valid. @@ -302,6 +338,12 @@ fn uleb_encoding_and_variant() { "invalid value: integer `4026531840`, expected variant index 0 <= i < 2".into() )) ); + assert_eq!( + from_bytes_via_reader::(&invalid_uleb), + Err(Error::Custom( + "invalid value: integer `4026531840`, expected variant index 0 <= i < 2".into() + )) + ); let invalid_uleb = vec![0x80, 0x80, 0x80, 0x00]; // Uleb decoder must reject non-canonical forms. @@ -309,6 +351,10 @@ fn uleb_encoding_and_variant() { from_bytes::(&invalid_uleb), Err(Error::NonCanonicalUleb128Encoding) ); + assert_eq!( + from_bytes_via_reader::(&invalid_uleb), + Err(Error::NonCanonicalUleb128Encoding) + ); } #[test] @@ -318,6 +364,10 @@ fn invalid_option() { from_bytes::>(&invalid_option), Err(Error::ExpectedOption) ); + assert_eq!( + from_bytes_via_reader::>(&invalid_option), + Err(Error::ExpectedOption) + ); } #[test] @@ -327,6 +377,10 @@ fn invalid_bool() { from_bytes::(&invalid_bool), Err(Error::ExpectedBoolean) ); + assert_eq!( + from_bytes_via_reader::(&invalid_bool), + Err(Error::ExpectedBoolean) + ); } #[test] @@ -353,6 +407,7 @@ fn variable_lengths() { fn sequence_not_long_enough() { let seq = vec![5, 1, 2, 3, 4]; // Missing 5th element assert_eq!(from_bytes::>(&seq), Err(Error::Eof)); + assert_eq!(from_bytes_via_reader::>(&seq), Err(Error::Eof)); } #[test] @@ -361,19 +416,28 @@ fn map_not_canonical() { map.insert(4u8, ()); map.insert(5u8, ()); let seq = vec![2, 4, 5]; - assert_eq!(from_bytes::>(&seq), Ok(map)); + assert_eq!(from_bytes::>(&seq).as_ref(), Ok(&map)); + assert_eq!(from_bytes_via_reader::>(&seq), Ok(map)); // Make sure out-of-order keys are rejected. let seq = vec![2, 5, 4]; assert_eq!( from_bytes::>(&seq), Err(Error::NonCanonicalMap) ); + assert_eq!( + from_bytes_via_reader::>(&seq), + Err(Error::NonCanonicalMap) + ); // Make sure duplicate keys are rejected. let seq = vec![2, 5, 5]; assert_eq!( from_bytes::>(&seq), Err(Error::NonCanonicalMap) ); + assert_eq!( + from_bytes_via_reader::>(&seq), + Err(Error::NonCanonicalMap) + ); } #[test] @@ -385,11 +449,14 @@ fn by_default_btreesets_are_serialized_as_sequences() { set.insert(5u8); let seq = vec![2, 4, 5]; assert_eq!(from_bytes::>(&seq), Ok(set.clone())); + assert_eq!(from_bytes_via_reader::>(&seq), Ok(set.clone())); let seq = vec![2, 5, 4]; assert_eq!(from_bytes::>(&seq), Ok(set.clone())); + assert_eq!(from_bytes_via_reader::>(&seq), Ok(set.clone())); // Duplicate keys are just ok. let seq = vec![3, 5, 5, 4]; - assert_eq!(from_bytes::>(&seq), Ok(set)); + assert_eq!(from_bytes::>(&seq).as_ref(), Ok(&set)); + assert_eq!(from_bytes_via_reader::>(&seq), Ok(set)); } #[test] @@ -457,6 +524,13 @@ fn cow() { Message::M1(b) => assert_eq!(b.into_owned(), large_object), _ => panic!(), } + + let deserialized: Message<'static> = from_bytes_via_reader(&serialized).unwrap(); + + match deserialized { + Message::M1(b) => assert_eq!(b.into_owned(), large_object), + _ => panic!(), + } } // M2 @@ -464,6 +538,12 @@ fn cow() { let serialized = to_bytes(&Message::M2(Cow::Borrowed(&large_map))).unwrap(); let deserialized: Message<'static> = from_bytes(&serialized).unwrap(); + match deserialized { + Message::M2(b) => assert_eq!(b.into_owned(), large_map), + _ => panic!(), + } + let deserialized: Message<'static> = from_bytes_via_reader(&serialized).unwrap(); + match deserialized { Message::M2(b) => assert_eq!(b.into_owned(), large_map), _ => panic!(), @@ -480,6 +560,9 @@ fn strbox() { let deserialized: Cow<'static, String> = from_bytes(&serialized).unwrap(); let stringx: String = deserialized.into_owned(); assert_eq!(strx, stringx); + let deserialized: Cow<'static, String> = from_bytes_via_reader(&serialized).unwrap(); + let stringx: String = deserialized.into_owned(); + assert_eq!(strx, stringx); } #[test] @@ -495,6 +578,14 @@ fn slicebox() { } let vecx: Vec = deserialized.into_owned(); assert_eq!(slice, vecx[..]); + + let deserialized: Cow<'static, Vec> = from_bytes_via_reader(&serialized).unwrap(); + { + let sb: &[u32] = &deserialized; + assert_eq!(slice, sb); + } + let vecx: Vec = deserialized.into_owned(); + assert_eq!(slice, vecx[..]); } #[test] @@ -505,6 +596,9 @@ fn path_buf() { let encoded = to_bytes(&path).unwrap(); let decoded: PathBuf = from_bytes(&encoded).unwrap(); assert!(path.to_str() == decoded.to_str()); + + let decoded: PathBuf = from_bytes_via_reader(&encoded).unwrap(); + assert!(path.to_str() == decoded.to_str()); } #[derive(Arbitrary, Debug, Deserialize, Serialize, PartialEq)] @@ -567,6 +661,9 @@ fn serde_known_vector() { // make sure we can deserialize the test vector into expected struct let deserialized_foo: Foo = from_bytes(&test_vector).unwrap(); assert_eq!(f, deserialized_foo); + + let deserialized_foo: Foo = from_bytes_via_reader(&test_vector).unwrap(); + assert_eq!(f, deserialized_foo); } #[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)] @@ -618,10 +715,12 @@ fn test_recursion_limit() { ] ); assert_eq!(from_bytes::>(&b1).unwrap(), l1); + assert_eq!(from_bytes_via_reader::>(&b1).unwrap(), l1); let l2 = List::integers(MAX_CONTAINER_DEPTH - 1); let b2 = to_bytes(&l2).unwrap(); assert_eq!(from_bytes::>(&b2).unwrap(), l2); + assert_eq!(from_bytes_via_reader::>(&b2).unwrap(), l2); let l3 = List::integers(MAX_CONTAINER_DEPTH); assert_eq!( to_bytes(&l3), @@ -633,12 +732,20 @@ fn test_recursion_limit() { from_bytes::>(&b3), Err(Error::ExceededContainerDepthLimit("List")) ); + assert_eq!( + from_bytes_via_reader::>(&b3), + Err(Error::ExceededContainerDepthLimit("List")) + ); let b2_pair = to_bytes(&(&l2, &l2)).unwrap(); assert_eq!( from_bytes::<(List<_>, List<_>)>(&b2_pair).unwrap(), (l2.clone(), l2.clone()) ); + assert_eq!( + from_bytes_via_reader::<(List<_>, List<_>)>(&b2_pair).unwrap(), + (l2.clone(), l2.clone()) + ); assert_eq!( to_bytes(&(&l2, &l3)), Err(Error::ExceededContainerDepthLimit("List")) @@ -663,10 +770,12 @@ fn test_recursion_limit_enum() { let b1 = to_bytes(&l1).unwrap(); assert_eq!(b1, vec![0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0],); assert_eq!(from_bytes::>(&b1).unwrap(), l1); + assert_eq!(from_bytes_via_reader::>(&b1).unwrap(), l1); let l2 = List::repeat(MAX_CONTAINER_DEPTH - 2, EnumA::ValueA); let b2 = to_bytes(&l2).unwrap(); assert_eq!(from_bytes::>(&b2).unwrap(), l2); + assert_eq!(from_bytes_via_reader::>(&b2).unwrap(), l2); let l3 = List::repeat(MAX_CONTAINER_DEPTH - 1, EnumA::ValueA); assert_eq!( @@ -679,4 +788,8 @@ fn test_recursion_limit_enum() { from_bytes::>(&b3), Err(Error::ExceededContainerDepthLimit("EnumA")) ); + assert_eq!( + from_bytes_via_reader::>(&b3), + Err(Error::ExceededContainerDepthLimit("EnumA")) + ); } From 902142b17dda751b9af5d438eaef57b6136a8a39 Mon Sep 17 00:00:00 2001 From: Preston Evans Date: Sat, 9 Sep 2023 13:22:28 -0500 Subject: [PATCH 5/8] Cleanup interface --- src/de/mod.rs | 83 +++++++++++++++++++++++++++++---------------------- src/lib.rs | 2 +- 2 files changed, 48 insertions(+), 37 deletions(-) diff --git a/src/de/mod.rs b/src/de/mod.rs index ad8de46..9d4f028 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -54,52 +54,63 @@ where } /// Deserialize a type from an implementation of [`Read`]. -pub fn from_reader(mut reader: &mut impl Read) -> Result +pub fn from_reader(reader: &mut impl Read) -> Result where T: DeserializeOwned, { - let mut deserializer = Deserializer::from_reader(&mut reader, crate::MAX_CONTAINER_DEPTH); + let mut deserializer = Deserializer::from_reader(reader, crate::MAX_CONTAINER_DEPTH); T::deserialize(&mut deserializer) } +/// Deserialize a type from an implementation of [`Read`] using the provided seed +pub fn from_reader_seed( + seed: T, + reader: &mut impl Read, +) -> Result<>::Value> +where + for<'a> T: DeserializeSeed<'a>, +{ + let mut deserializer = Deserializer::from_reader(reader, crate::MAX_CONTAINER_DEPTH); + seed.deserialize(&mut deserializer) +} + /// Deserialization implementation for BCS -struct Deserializer<'de, R> { +struct Deserializer { input: R, max_remaining_depth: usize, - _phantom: std::marker::PhantomData<&'de ()>, } -impl<'de, R: Read> Deserializer<'de, TeeReader<&'de mut R>> { +impl<'de, R: Read> Deserializer> { fn from_reader(input: &'de mut R, max_remaining_depth: usize) -> Self { Deserializer { input: TeeReader::new(input), max_remaining_depth, - _phantom: std::marker::PhantomData, } } } -impl<'de> Deserializer<'de, &'de [u8]> { +impl<'de> Deserializer<&'de [u8]> { /// Creates a new `Deserializer` which will be deserializing the provided /// input. fn new(input: &'de [u8], max_remaining_depth: usize) -> Self { Deserializer { input, max_remaining_depth, - _phantom: std::marker::PhantomData, } } } /// A reader that can optionally capture all bytes from an underlying [`Read`]er -struct TeeReader { - reader: R, +struct TeeReader<'de, R> { + /// the underlying reader + reader: &'de mut R, + /// If set, all bytes read from the underlying reader will be duplicated here capture_buffer: Option>, } -impl TeeReader { +impl<'de, R> TeeReader<'de, R> { /// Wrapse the provided reader in a new [`TeeReader`]. - pub fn new(reader: R) -> Self { + pub fn new(reader: &'de mut R) -> Self { Self { reader, capture_buffer: Default::default(), @@ -107,7 +118,7 @@ impl TeeReader { } } -impl Read for TeeReader { +impl<'de, R: Read> Read for TeeReader<'de, R> { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { let bytes_read = self.reader.read(buf)?; if let Some(ref mut buffer) = self.capture_buffer { @@ -211,7 +222,7 @@ trait BcsDeserializer<'de> { } } -impl<'de, R: Read> Deserializer<'de, TeeReader> { +impl<'de, R: Read> Deserializer> { fn parse_vec(&mut self) -> Result> { let len = self.parse_length()?; let mut output = vec![0; len]; @@ -225,7 +236,7 @@ impl<'de, R: Read> Deserializer<'de, TeeReader> { } } -impl<'de, R: Read> BcsDeserializer<'de> for Deserializer<'de, TeeReader> { +impl<'de, R: Read> BcsDeserializer<'de> for Deserializer> { type MaybeBorrowedBytes = Vec; fn fill_slice(&mut self, slice: &mut [u8]) -> Result<()> { @@ -257,7 +268,7 @@ impl<'de, R: Read> BcsDeserializer<'de> for Deserializer<'de, TeeReader> { } } -impl<'de> BcsDeserializer<'de> for Deserializer<'de, &'de [u8]> { +impl<'de> BcsDeserializer<'de> for Deserializer<&'de [u8]> { type MaybeBorrowedBytes = &'de [u8]; fn next(&mut self) -> Result { let byte = self.peek()?; @@ -298,7 +309,7 @@ impl<'de> BcsDeserializer<'de> for Deserializer<'de, &'de [u8]> { } } -impl<'de> Deserializer<'de, &'de [u8]> { +impl<'de> Deserializer<&'de [u8]> { fn peek(&mut self) -> Result { self.input.first().copied().ok_or(Error::Eof) } @@ -327,7 +338,7 @@ impl<'de> Deserializer<'de, &'de [u8]> { } } -impl<'de, R> Deserializer<'de, R> { +impl<'de, R> Deserializer { fn enter_named_container(&mut self, name: &'static str) -> Result<()> { if self.max_remaining_depth == 0 { return Err(Error::ExceededContainerDepthLimit(name)); @@ -341,9 +352,9 @@ impl<'de, R> Deserializer<'de, R> { } } -impl<'de, 'a, R> de::Deserializer<'de> for &'a mut Deserializer<'de, R> +impl<'de, 'a, R> de::Deserializer<'de> for &'a mut Deserializer where - Deserializer<'de, R>: BcsDeserializer<'de>, + Deserializer: BcsDeserializer<'de>, { type Error = Error; @@ -611,20 +622,20 @@ where } } -struct SeqDeserializer<'a, 'de: 'a, R> { - de: &'a mut Deserializer<'de, R>, +struct SeqDeserializer<'a, R> { + de: &'a mut Deserializer, remaining: usize, } #[allow(clippy::needless_borrow)] -impl<'a, 'de, R> SeqDeserializer<'a, 'de, R> { - fn new(de: &'a mut Deserializer<'de, R>, remaining: usize) -> Self { +impl<'a, R> SeqDeserializer<'a, R> { + fn new(de: &'a mut Deserializer, remaining: usize) -> Self { Self { de, remaining } } } -impl<'de, 'a, R> de::SeqAccess<'de> for SeqDeserializer<'a, 'de, R> +impl<'a, 'de, R> de::SeqAccess<'de> for SeqDeserializer<'a, R> where - Deserializer<'de, R>: BcsDeserializer<'de>, + Deserializer: BcsDeserializer<'de>, { type Error = Error; @@ -645,14 +656,14 @@ where } } -struct MapDeserializer<'a, 'de: 'a, R, B> { - de: &'a mut Deserializer<'de, R>, +struct MapDeserializer<'a, R, B> { + de: &'a mut Deserializer, remaining: usize, previous_key_bytes: Option, } -impl<'a, 'de, R, B> MapDeserializer<'a, 'de, R, B> { - fn new(de: &'a mut Deserializer<'de, R>, remaining: usize) -> Self { +impl<'a, R, B> MapDeserializer<'a, R, B> { + fn new(de: &'a mut Deserializer, remaining: usize) -> Self { Self { de, remaining, @@ -661,9 +672,9 @@ impl<'a, 'de, R, B> MapDeserializer<'a, 'de, R, B> { } } -impl<'de, 'a, R, B: AsRef<[u8]>> de::MapAccess<'de> for MapDeserializer<'a, 'de, R, B> +impl<'de, 'a, R, B: AsRef<[u8]>> de::MapAccess<'de> for MapDeserializer<'a, R, B> where - Deserializer<'de, R>: BcsDeserializer<'de, MaybeBorrowedBytes = B>, + Deserializer: BcsDeserializer<'de, MaybeBorrowedBytes = B>, { type Error = Error; @@ -699,9 +710,9 @@ where } } -impl<'de, 'a, R> de::EnumAccess<'de> for &'a mut Deserializer<'de, R> +impl<'de, 'a, R> de::EnumAccess<'de> for &'a mut Deserializer where - Deserializer<'de, R>: BcsDeserializer<'de>, + Deserializer: BcsDeserializer<'de>, { type Error = Error; type Variant = Self; @@ -716,9 +727,9 @@ where } } -impl<'de, 'a, R> de::VariantAccess<'de> for &'a mut Deserializer<'de, R> +impl<'de, 'a, R> de::VariantAccess<'de> for &'a mut Deserializer where - Deserializer<'de, R>: BcsDeserializer<'de>, + Deserializer: BcsDeserializer<'de>, { type Error = Error; diff --git a/src/lib.rs b/src/lib.rs index 3f4d452..6e4feec 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -314,6 +314,6 @@ pub const MAX_SEQUENCE_LENGTH: usize = (1 << 31) - 1; /// Maximal allowed depth of BCS data, counting only structs and enums. pub const MAX_CONTAINER_DEPTH: usize = 500; -pub use de::{from_bytes, from_bytes_seed, from_reader}; +pub use de::{from_bytes, from_bytes_seed, from_reader, from_reader_seed}; pub use error::{Error, Result}; pub use ser::{is_human_readable, serialize_into, serialized_size, to_bytes}; From 0ba5d805e4f5059f34615bc676fbf6d617822692 Mon Sep 17 00:00:00 2001 From: Preston Evans Date: Sat, 9 Sep 2023 13:23:10 -0500 Subject: [PATCH 6/8] Revert to original layout --- src/{de/mod.rs => de.rs} | 0 src/de/from_bytes.rs | 518 --------------------------------------- src/de/from_reader.rs | 475 ----------------------------------- 3 files changed, 993 deletions(-) rename src/{de/mod.rs => de.rs} (100%) delete mode 100644 src/de/from_bytes.rs delete mode 100644 src/de/from_reader.rs diff --git a/src/de/mod.rs b/src/de.rs similarity index 100% rename from src/de/mod.rs rename to src/de.rs diff --git a/src/de/from_bytes.rs b/src/de/from_bytes.rs deleted file mode 100644 index 133ef07..0000000 --- a/src/de/from_bytes.rs +++ /dev/null @@ -1,518 +0,0 @@ -use crate::error::{Error, Result}; -use serde::de::{self, Deserialize, DeserializeSeed, IntoDeserializer, Visitor}; - -use super::BcsDeserializer; - -/// Deserializes a `&[u8]` into a type. -/// -/// This function will attempt to interpret `bytes` as the BCS serialized form of `T` and -/// deserialize `T` from `bytes`. -/// -/// # Examples -/// -/// ``` -/// use bcs::from_bytes; -/// use serde::Deserialize; -/// -/// #[derive(Deserialize)] -/// struct Ip([u8; 4]); -/// -/// #[derive(Deserialize)] -/// struct Port(u16); -/// -/// #[derive(Deserialize)] -/// struct SocketAddr { -/// ip: Ip, -/// port: Port, -/// } -/// -/// let bytes = vec![0x7f, 0x00, 0x00, 0x01, 0x41, 0x1f]; -/// let socket_addr: SocketAddr = from_bytes(&bytes).unwrap(); -/// -/// assert_eq!(socket_addr.ip.0, [127, 0, 0, 1]); -/// assert_eq!(socket_addr.port.0, 8001); -/// ``` -pub fn from_bytes<'a, T>(bytes: &'a [u8]) -> Result -where - T: Deserialize<'a>, -{ - let mut deserializer = Deserializer::new(bytes, crate::MAX_CONTAINER_DEPTH); - let t = T::deserialize(&mut deserializer)?; - deserializer.end().map(move |_| t) -} - -/// Perform a stateful deserialization from a `&[u8]` using the provided `seed`. -pub fn from_bytes_seed<'a, T>(seed: T, bytes: &'a [u8]) -> Result -where - T: DeserializeSeed<'a>, -{ - let mut deserializer = Deserializer::new(bytes, crate::MAX_CONTAINER_DEPTH); - let t = seed.deserialize(&mut deserializer)?; - deserializer.end().map(move |_| t) -} - -impl<'de> BcsDeserializer for Deserializer<'de> { - fn next(&mut self) -> Result { - let byte = self.peek()?; - self.input = &self.input[1..]; - Ok(byte) - } - - fn max_remaining_depth(&mut self) -> usize { - self.max_remaining_depth - } - - fn max_remaining_depth_mut(&mut self) -> &mut usize { - &mut self.max_remaining_depth - } - - fn fill_slice(&mut self, slice: &mut [u8]) -> Result<()> { - for byte in slice { - *byte = self.next()?; - } - Ok(()) - } -} - -/// Deserialization implementation for BCS -struct Deserializer<'de> { - input: &'de [u8], - max_remaining_depth: usize, -} - -impl<'de> Deserializer<'de> { - /// Creates a new `Deserializer` which will be deserializing the provided - /// input. - fn new(input: &'de [u8], max_remaining_depth: usize) -> Self { - Deserializer { - input, - max_remaining_depth, - } - } - - /// The `Deserializer::end` method should be called after a type has been - /// fully deserialized. This allows the `Deserializer` to validate that - /// the there are no more bytes remaining in the input stream. - fn end(&mut self) -> Result<()> { - if self.input.is_empty() { - Ok(()) - } else { - Err(Error::RemainingInput) - } - } -} - -impl<'de> Deserializer<'de> { - fn peek(&mut self) -> Result { - self.input.first().copied().ok_or(Error::Eof) - } - - fn parse_string_borrowed(&mut self) -> Result<&'de str> { - let slice = self.parse_bytes_borrowed()?; - std::str::from_utf8(slice).map_err(|_| Error::Utf8) - } - - fn parse_bytes_borrowed(&mut self) -> Result<&'de [u8]> { - let len = self.parse_length()?; - let slice = self.input.get(..len).ok_or(Error::Eof)?; - self.input = &self.input[len..]; - Ok(slice) - } -} - -impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { - type Error = Error; - - // BCS is not a self-describing format so we can't implement `deserialize_any` - fn deserialize_any(self, _visitor: V) -> Result - where - V: Visitor<'de>, - { - Err(Error::NotSupported("deserialize_any")) - } - - fn deserialize_bool(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_bool(self.parse_bool()?) - } - - fn deserialize_i8(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_i8(self.parse_u8()? as i8) - } - - fn deserialize_i16(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_i16(self.parse_u16()? as i16) - } - - fn deserialize_i32(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_i32(self.parse_u32()? as i32) - } - - fn deserialize_i64(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_i64(self.parse_u64()? as i64) - } - - fn deserialize_i128(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_i128(self.parse_u128()? as i128) - } - - fn deserialize_u8(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_u8(self.parse_u8()?) - } - - fn deserialize_u16(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_u16(self.parse_u16()?) - } - - fn deserialize_u32(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_u32(self.parse_u32()?) - } - - fn deserialize_u64(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_u64(self.parse_u64()?) - } - - fn deserialize_u128(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_u128(self.parse_u128()?) - } - - fn deserialize_f32(self, _visitor: V) -> Result - where - V: Visitor<'de>, - { - Err(Error::NotSupported("deserialize_f32")) - } - - fn deserialize_f64(self, _visitor: V) -> Result - where - V: Visitor<'de>, - { - Err(Error::NotSupported("deserialize_f64")) - } - - fn deserialize_char(self, _visitor: V) -> Result - where - V: Visitor<'de>, - { - Err(Error::NotSupported("deserialize_char")) - } - - fn deserialize_str(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_borrowed_str(self.parse_string_borrowed()?) - } - - fn deserialize_string(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.deserialize_str(visitor) - } - - fn deserialize_bytes(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_borrowed_bytes(self.parse_bytes_borrowed()?) - } - - fn deserialize_byte_buf(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.deserialize_bytes(visitor) - } - - fn deserialize_option(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - let byte = self.next()?; - - match byte { - 0 => visitor.visit_none(), - 1 => visitor.visit_some(self), - _ => Err(Error::ExpectedOption), - } - } - - fn deserialize_unit(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_unit() - } - - fn deserialize_unit_struct(self, name: &'static str, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.enter_named_container(name)?; - let r = self.deserialize_unit(visitor); - self.leave_named_container(); - r - } - - fn deserialize_newtype_struct(self, name: &'static str, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.enter_named_container(name)?; - let r = visitor.visit_newtype_struct(&mut *self); - self.leave_named_container(); - r - } - #[allow(clippy::needless_borrow)] - fn deserialize_seq(mut self, visitor: V) -> Result - where - V: Visitor<'de>, - { - let len = self.parse_length()?; - visitor.visit_seq(SeqDeserializer::new(&mut self, len)) - } - #[allow(clippy::needless_borrow)] - fn deserialize_tuple(mut self, len: usize, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_seq(SeqDeserializer::new(&mut self, len)) - } - #[allow(clippy::needless_borrow)] - fn deserialize_tuple_struct( - mut self, - name: &'static str, - len: usize, - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - self.enter_named_container(name)?; - let r = visitor.visit_seq(SeqDeserializer::new(&mut self, len)); - self.leave_named_container(); - r - } - #[allow(clippy::needless_borrow)] - fn deserialize_map(mut self, visitor: V) -> Result - where - V: Visitor<'de>, - { - let len = self.parse_length()?; - visitor.visit_map(MapDeserializer::new(&mut self, len)) - } - #[allow(clippy::needless_borrow)] - fn deserialize_struct( - mut self, - name: &'static str, - fields: &'static [&'static str], - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - self.enter_named_container(name)?; - let r = visitor.visit_seq(SeqDeserializer::new(&mut self, fields.len())); - self.leave_named_container(); - r - } - - fn deserialize_enum( - self, - name: &'static str, - _variants: &'static [&'static str], - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - self.enter_named_container(name)?; - let r = visitor.visit_enum(&mut *self); - self.leave_named_container(); - r - } - - // BCS does not utilize identifiers, so throw them away - fn deserialize_identifier(self, _visitor: V) -> Result - where - V: Visitor<'de>, - { - self.deserialize_bytes(_visitor) - } - - // BCS is not a self-describing format so we can't implement `deserialize_ignored_any` - fn deserialize_ignored_any(self, _visitor: V) -> Result - where - V: Visitor<'de>, - { - Err(Error::NotSupported("deserialize_ignored_any")) - } - - // BCS is not a human readable format - fn is_human_readable(&self) -> bool { - false - } -} - -struct SeqDeserializer<'a, 'de: 'a> { - de: &'a mut Deserializer<'de>, - remaining: usize, -} -#[allow(clippy::needless_borrow)] -impl<'a, 'de> SeqDeserializer<'a, 'de> { - fn new(de: &'a mut Deserializer<'de>, remaining: usize) -> Self { - Self { de, remaining } - } -} - -impl<'de, 'a> de::SeqAccess<'de> for SeqDeserializer<'a, 'de> { - type Error = Error; - - fn next_element_seed(&mut self, seed: T) -> Result> - where - T: DeserializeSeed<'de>, - { - if self.remaining == 0 { - Ok(None) - } else { - self.remaining -= 1; - seed.deserialize(&mut *self.de).map(Some) - } - } - - fn size_hint(&self) -> Option { - Some(self.remaining) - } -} - -struct MapDeserializer<'a, 'de: 'a> { - de: &'a mut Deserializer<'de>, - remaining: usize, - previous_key_bytes: Option<&'a [u8]>, -} - -impl<'a, 'de> MapDeserializer<'a, 'de> { - fn new(de: &'a mut Deserializer<'de>, remaining: usize) -> Self { - Self { - de, - remaining, - previous_key_bytes: None, - } - } -} - -impl<'de, 'a> de::MapAccess<'de> for MapDeserializer<'a, 'de> { - type Error = Error; - - fn next_key_seed(&mut self, seed: K) -> Result> - where - K: DeserializeSeed<'de>, - { - match self.remaining.checked_sub(1) { - None => Ok(None), - Some(remaining) => { - let previous_input_slice = self.de.input; - let key_value = seed.deserialize(&mut *self.de)?; - let key_len = previous_input_slice - .len() - .saturating_sub(self.de.input.len()); - let key_bytes = &previous_input_slice[..key_len]; - if let Some(previous_key_bytes) = self.previous_key_bytes { - if previous_key_bytes >= key_bytes { - return Err(Error::NonCanonicalMap); - } - } - self.remaining = remaining; - self.previous_key_bytes = Some(key_bytes); - Ok(Some(key_value)) - } - } - } - - fn next_value_seed(&mut self, seed: V) -> Result - where - V: DeserializeSeed<'de>, - { - seed.deserialize(&mut *self.de) - } - - fn size_hint(&self) -> Option { - Some(self.remaining) - } -} - -impl<'de, 'a> de::EnumAccess<'de> for &'a mut Deserializer<'de> { - type Error = Error; - type Variant = Self; - - fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant)> - where - V: DeserializeSeed<'de>, - { - let variant_index = self.parse_u32_from_uleb128()?; - let result: Result = seed.deserialize(variant_index.into_deserializer()); - Ok((result?, self)) - } -} - -impl<'de, 'a> de::VariantAccess<'de> for &'a mut Deserializer<'de> { - type Error = Error; - - fn unit_variant(self) -> Result<()> { - Ok(()) - } - - fn newtype_variant_seed(self, seed: T) -> Result - where - T: DeserializeSeed<'de>, - { - seed.deserialize(self) - } - - fn tuple_variant(self, len: usize, visitor: V) -> Result - where - V: Visitor<'de>, - { - de::Deserializer::deserialize_tuple(self, len, visitor) - } - - fn struct_variant(self, fields: &'static [&'static str], visitor: V) -> Result - where - V: Visitor<'de>, - { - de::Deserializer::deserialize_tuple(self, fields.len(), visitor) - } -} diff --git a/src/de/from_reader.rs b/src/de/from_reader.rs deleted file mode 100644 index e550cbc..0000000 --- a/src/de/from_reader.rs +++ /dev/null @@ -1,475 +0,0 @@ -use crate::error::{Error, Result}; -use serde::de::{self, DeserializeSeed, IntoDeserializer, Visitor}; -use std::io::Read; - -use super::BcsDeserializer; - -/// Deserializes BCS from a [`std::io::Read`]er -pub struct DeserializeReader<'de, R> { - reader: TeeReader<'de, R>, - max_remaining_depth: usize, -} - -impl<'de, R: Read> DeserializeReader<'de, R> { - /// Wraps the provided reader in a new [`DeserializeReader`] - fn new(reader: &'de mut R, max_remaining_depth: usize) -> Self { - DeserializeReader { - reader: TeeReader::new(reader), - max_remaining_depth, - } - } -} - -impl<'de, R: Read> BcsDeserializer for DeserializeReader<'de, R> { - fn fill_slice(&mut self, slice: &mut [u8]) -> Result<()> { - Ok(self.reader.read_exact(&mut slice[..])?) - } - - fn max_remaining_depth(&mut self) -> usize { - self.max_remaining_depth - } - - fn max_remaining_depth_mut(&mut self) -> &mut usize { - &mut self.max_remaining_depth - } -} - -impl<'de, R: Read> DeserializeReader<'de, R> { - /// Parse a vector of bytes from the reader - fn parse_vec(&mut self) -> Result> { - let len = self.parse_length()?; - let mut output = vec![0; len]; - self.fill_slice(&mut output)?; - Ok(output) - } - - /// Parse a String from the reader - fn parse_string(&mut self) -> Result { - let bytes = self.parse_vec()?; - String::from_utf8(bytes).map_err(|_| Error::Utf8) - } -} - -impl<'de, 'a, R: Read> de::Deserializer<'de> for &'a mut DeserializeReader<'de, R> { - type Error = Error; - - // BCS is not a self-describing format so we can't implement `deserialize_any` - fn deserialize_any(self, _visitor: V) -> Result - where - V: Visitor<'de>, - { - Err(Error::NotSupported("deserialize_any")) - } - - fn deserialize_bool(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_bool(self.parse_bool()?) - } - - fn deserialize_i8(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_i8(self.parse_u8()? as i8) - } - - fn deserialize_i16(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_i16(self.parse_u16()? as i16) - } - - fn deserialize_i32(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_i32(self.parse_u32()? as i32) - } - - fn deserialize_i64(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_i64(self.parse_u64()? as i64) - } - - fn deserialize_i128(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_i128(self.parse_u128()? as i128) - } - - fn deserialize_u8(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_u8(self.parse_u8()?) - } - - fn deserialize_u16(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_u16(self.parse_u16()?) - } - - fn deserialize_u32(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_u32(self.parse_u32()?) - } - - fn deserialize_u64(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_u64(self.parse_u64()?) - } - - fn deserialize_u128(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_u128(self.parse_u128()?) - } - - fn deserialize_f32(self, _visitor: V) -> Result - where - V: Visitor<'de>, - { - Err(Error::NotSupported("deserialize_f32")) - } - - fn deserialize_f64(self, _visitor: V) -> Result - where - V: Visitor<'de>, - { - Err(Error::NotSupported("deserialize_f64")) - } - - fn deserialize_char(self, _visitor: V) -> Result - where - V: Visitor<'de>, - { - Err(Error::NotSupported("deserialize_char")) - } - - fn deserialize_str(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_string(self.parse_string()?) - } - - fn deserialize_string(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.deserialize_str(visitor) - } - - fn deserialize_bytes(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_byte_buf(self.parse_vec()?) - } - - fn deserialize_byte_buf(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.deserialize_bytes(visitor) - } - - fn deserialize_option(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - let byte = self.next()?; - - match byte { - 0 => visitor.visit_none(), - 1 => visitor.visit_some(self), - _ => Err(Error::ExpectedOption), - } - } - - fn deserialize_unit(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_unit() - } - - fn deserialize_unit_struct(self, name: &'static str, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.enter_named_container(name)?; - let r = self.deserialize_unit(visitor); - self.leave_named_container(); - r - } - - fn deserialize_newtype_struct(self, name: &'static str, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.enter_named_container(name)?; - let r = visitor.visit_newtype_struct(&mut *self); - self.leave_named_container(); - r - } - #[allow(clippy::needless_borrow)] - fn deserialize_seq(mut self, visitor: V) -> Result - where - V: Visitor<'de>, - { - let len = self.parse_length()?; - visitor.visit_seq(SeqDeserializer::new(&mut self, len)) - } - #[allow(clippy::needless_borrow)] - fn deserialize_tuple(mut self, len: usize, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_seq(SeqDeserializer::new(&mut self, len)) - } - #[allow(clippy::needless_borrow)] - fn deserialize_tuple_struct( - mut self, - name: &'static str, - len: usize, - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - self.enter_named_container(name)?; - let r = visitor.visit_seq(SeqDeserializer::new(&mut self, len)); - self.leave_named_container(); - r - } - #[allow(clippy::needless_borrow)] - fn deserialize_map(mut self, visitor: V) -> Result - where - V: Visitor<'de>, - { - let len = self.parse_length()?; - visitor.visit_map(MapDeserializer::new(&mut self, len)) - } - #[allow(clippy::needless_borrow)] - fn deserialize_struct( - mut self, - name: &'static str, - fields: &'static [&'static str], - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - self.enter_named_container(name)?; - let r = visitor.visit_seq(SeqDeserializer::new(&mut self, fields.len())); - self.leave_named_container(); - r - } - - fn deserialize_enum( - self, - name: &'static str, - _variants: &'static [&'static str], - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - self.enter_named_container(name)?; - let r = visitor.visit_enum(&mut *self); - self.leave_named_container(); - r - } - - // BCS does not utilize identifiers, so throw them away - fn deserialize_identifier(self, _visitor: V) -> Result - where - V: Visitor<'de>, - { - self.deserialize_bytes(_visitor) - } - - // BCS is not a self-describing format so we can't implement `deserialize_ignored_any` - fn deserialize_ignored_any(self, _visitor: V) -> Result - where - V: Visitor<'de>, - { - Err(Error::NotSupported("deserialize_ignored_any")) - } - - // BCS is not a human readable format - fn is_human_readable(&self) -> bool { - false - } -} - -struct SeqDeserializer<'a, 'de: 'a, R> { - de: &'a mut DeserializeReader<'de, R>, - remaining: usize, -} -#[allow(clippy::needless_borrow)] -impl<'a, 'de: 'a, R> SeqDeserializer<'a, 'de, R> { - fn new(de: &'a mut DeserializeReader<'de, R>, remaining: usize) -> Self { - Self { de, remaining } - } -} - -impl<'de, 'a, R: Read> de::SeqAccess<'de> for SeqDeserializer<'a, 'de, R> { - type Error = Error; - - fn next_element_seed(&mut self, seed: T) -> Result> - where - T: DeserializeSeed<'de>, - { - if self.remaining == 0 { - Ok(None) - } else { - self.remaining -= 1; - seed.deserialize(&mut *self.de).map(Some) - } - } - - fn size_hint(&self) -> Option { - Some(self.remaining) - } -} - -/// A reader that can optionally capture all bytes from an underlying [`Read`]er -pub struct TeeReader<'a, R> { - reader: &'a mut R, - capture_buffer: Option>, -} - -impl<'a, R> TeeReader<'a, R> { - /// Wrapse the provided reader in a new [`TeeReader`]. - pub fn new(reader: &'a mut R) -> Self { - Self { - reader, - capture_buffer: Default::default(), - } - } -} - -impl<'a, R: Read> Read for TeeReader<'a, R> { - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - let bytes_read = self.reader.read(buf)?; - if let Some(ref mut buffer) = self.capture_buffer { - buffer.extend_from_slice(&buf[..bytes_read]); - } - Ok(bytes_read) - } -} - -struct MapDeserializer<'a, 'de: 'a, R> { - de: &'a mut DeserializeReader<'de, R>, - remaining: usize, - previous_key_bytes: Option>, -} - -impl<'a, 'de, R: Read> MapDeserializer<'a, 'de, R> { - fn new(de: &'a mut DeserializeReader<'de, R>, remaining: usize) -> Self { - Self { - de, - remaining, - previous_key_bytes: None, - } - } -} - -impl<'de, 'a, R: Read> de::MapAccess<'de> for MapDeserializer<'a, 'de, R> -where - 'de: 'a, -{ - type Error = Error; - - fn next_key_seed(&mut self, seed: K) -> Result> - where - K: DeserializeSeed<'de>, - { - match self.remaining.checked_sub(1) { - None => Ok(None), - Some(remaining) => { - self.de.reader.capture_buffer = Some(Vec::new()); - let key_value = seed.deserialize(&mut *self.de)?; - let key_bytes = self.de.reader.capture_buffer.take().unwrap(); - - if let Some(ref previous_key_bytes) = self.previous_key_bytes { - if previous_key_bytes.as_slice() >= key_bytes.as_slice() { - return Err(Error::NonCanonicalMap); - } - } - self.remaining = remaining; - self.previous_key_bytes = Some(key_bytes); - Ok(Some(key_value)) - } - } - } - - fn next_value_seed(&mut self, seed: V) -> Result - where - V: DeserializeSeed<'de>, - { - seed.deserialize(&mut *self.de) - } - - fn size_hint(&self) -> Option { - Some(self.remaining) - } -} - -impl<'a, 'de: 'a, R: Read> de::EnumAccess<'de> for &'a mut DeserializeReader<'de, R> { - type Error = Error; - type Variant = Self; - - fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant)> - where - V: DeserializeSeed<'de>, - { - let variant_index = self.parse_u32_from_uleb128()?; - let result: Result = seed.deserialize(variant_index.into_deserializer()); - Ok((result?, self)) - } -} - -impl<'a, 'de: 'a, R: Read> de::VariantAccess<'de> for &'a mut DeserializeReader<'de, R> { - type Error = Error; - - fn unit_variant(self) -> Result<()> { - Ok(()) - } - - fn newtype_variant_seed(self, seed: T) -> Result - where - T: DeserializeSeed<'de>, - { - seed.deserialize(self) - } - - fn tuple_variant(self, len: usize, visitor: V) -> Result - where - V: Visitor<'de>, - { - de::Deserializer::deserialize_tuple(self, len, visitor) - } - - fn struct_variant(self, fields: &'static [&'static str], visitor: V) -> Result - where - V: Visitor<'de>, - { - de::Deserializer::deserialize_tuple(self, fields.len(), visitor) - } -} From e4f1861509a996408baa540cda1f88c290214ce1 Mon Sep 17 00:00:00 2001 From: Preston Evans Date: Sat, 9 Sep 2023 14:28:05 -0500 Subject: [PATCH 7/8] Check that reader empty --- src/de.rs | 35 ++++++++++++++++++++++++----------- tests/serde.rs | 4 ++++ 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/src/de.rs b/src/de.rs index 9d4f028..2c7f738 100644 --- a/src/de.rs +++ b/src/de.rs @@ -59,7 +59,8 @@ where T: DeserializeOwned, { let mut deserializer = Deserializer::from_reader(reader, crate::MAX_CONTAINER_DEPTH); - T::deserialize(&mut deserializer) + let t = T::deserialize(&mut deserializer)?; + deserializer.end().map(move |_| t) } /// Deserialize a type from an implementation of [`Read`] using the provided seed @@ -71,7 +72,8 @@ where for<'a> T: DeserializeSeed<'a>, { let mut deserializer = Deserializer::from_reader(reader, crate::MAX_CONTAINER_DEPTH); - seed.deserialize(&mut deserializer) + let t = seed.deserialize(&mut deserializer)?; + deserializer.end().map(move |_| t) } /// Deserialization implementation for BCS @@ -146,6 +148,11 @@ trait BcsDeserializer<'de> { seed: K, ) -> Result<(K::Value, Self::MaybeBorrowedBytes), Error>; + /// The `Deserializer::end` method should be called after a type has been + /// fully deserialized. This allows the `Deserializer` to validate that + /// the there are no more bytes remaining in the input stream. + fn end(&mut self) -> Result<()>; + fn parse_bool(&mut self) -> Result { let byte = self.next()?; @@ -266,6 +273,15 @@ impl<'de, R: Read> BcsDeserializer<'de> for Deserializer> { let key_bytes = self.input.capture_buffer.take().unwrap(); Ok((key_value, key_bytes)) } + + fn end(&mut self) -> Result<()> { + let mut byte = [0u8; 1]; + match self.input.read_exact(&mut byte) { + Ok(_) => Err(Error::RemainingInput), + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Ok(()), + Err(e) => Err(e.into()), + } + } } impl<'de> BcsDeserializer<'de> for Deserializer<&'de [u8]> { @@ -307,16 +323,7 @@ impl<'de> BcsDeserializer<'de> for Deserializer<&'de [u8]> { let key_bytes = &previous_input_slice[..key_len]; Ok((key_value, key_bytes)) } -} - -impl<'de> Deserializer<&'de [u8]> { - fn peek(&mut self) -> Result { - self.input.first().copied().ok_or(Error::Eof) - } - /// The `Deserializer::end` method should be called after a type has been - /// fully deserialized. This allows the `Deserializer` to validate that - /// the there are no more bytes remaining in the input stream. fn end(&mut self) -> Result<()> { if self.input.is_empty() { Ok(()) @@ -324,6 +331,12 @@ impl<'de> Deserializer<&'de [u8]> { Err(Error::RemainingInput) } } +} + +impl<'de> Deserializer<&'de [u8]> { + fn peek(&mut self) -> Result { + self.input.first().copied().ok_or(Error::Eof) + } fn parse_bytes(&mut self) -> Result<&'de [u8]> { let len = self.parse_length()?; diff --git a/tests/serde.rs b/tests/serde.rs index 0d49e0b..caa26e9 100644 --- a/tests/serde.rs +++ b/tests/serde.rs @@ -463,6 +463,10 @@ fn by_default_btreesets_are_serialized_as_sequences() { fn leftover_bytes() { let seq = vec![5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; // 5 extra elements assert_eq!(from_bytes::>(&seq), Err(Error::RemainingInput)); + assert_eq!( + from_bytes_via_reader::>(&seq), + Err(Error::RemainingInput) + ); } #[test] From 387f7fdca35045f3185dbdabbe2cc7e55da76834 Mon Sep 17 00:00:00 2001 From: Preston Evans Date: Mon, 11 Sep 2023 09:03:00 -0500 Subject: [PATCH 8/8] remove unused lifetime --- src/de.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/de.rs b/src/de.rs index 2c7f738..2588823 100644 --- a/src/de.rs +++ b/src/de.rs @@ -351,7 +351,7 @@ impl<'de> Deserializer<&'de [u8]> { } } -impl<'de, R> Deserializer { +impl Deserializer { fn enter_named_container(&mut self, name: &'static str) -> Result<()> { if self.max_remaining_depth == 0 { return Err(Error::ExceededContainerDepthLimit(name));