From ab944d114c3ec3e7e118eab4c9b27a2cd4e39f66 Mon Sep 17 00:00:00 2001 From: Steve Fan <29133953+stevefan1999-personal@users.noreply.github.com> Date: Fri, 24 Jun 2022 14:50:18 +0000 Subject: [PATCH 1/2] update the dependencies to latest version --- Cargo.toml | 15 +++---- src/base64.rs | 107 -------------------------------------------------- src/http.rs | 21 +++++----- src/lib.rs | 12 +++--- 4 files changed, 24 insertions(+), 131 deletions(-) delete mode 100644 src/base64.rs diff --git a/Cargo.toml b/Cargo.toml index eb6cc53..ade4717 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,17 +11,18 @@ categories = ["embedded", "no-std", "network-programming"] readme = "README.md" [dependencies] -sha1 = "0.6" -heapless = "0.5" -byteorder = { version = "1.4", default-features = false } -httparse = { version = "1.4", default-features = false } -rand_core = "0.6" +sha1 = "0.10.1" +heapless = "0.7.14" +byteorder = { version = "1.4.3", default-features = false } +httparse = { version = "1.7.1", default-features = false } +rand_core = "0.6.3" +base64 = { version = "0.13.0", default-features = false } [dev-dependencies] -rand = "0.8.3" +rand = "0.8.5" # see readme for no_std support [features] default = ["std"] # default = [] -std = [] \ No newline at end of file +std = [] diff --git a/src/base64.rs b/src/base64.rs deleted file mode 100644 index 12f3391..0000000 --- a/src/base64.rs +++ /dev/null @@ -1,107 +0,0 @@ -// *************************************** BASE64 ENCODE ****************************************** -// The base64_encode function below was adapted from the rust-base64 library -// https://github.com/alicemaz/rust-base64 -// The MIT License (MIT) -// Copyright (c) 2015 Alice Maz, 2019 David Haig -// Adapted for no_std specifically for MIME (Standard) flavoured base64 encoding -// ************************************************************************************************ - -pub const BASE64_ENCODE_TABLE: &[u8; 64] = &[ - 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, - 89, 90, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, - 115, 116, 117, 118, 119, 120, 121, 122, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 43, 47, -]; - -pub fn encode(input: &[u8], output: &mut [u8]) -> usize { - let encode_table: &[u8; 64] = BASE64_ENCODE_TABLE; - let mut input_index: usize = 0; - let mut output_index = 0; - const LOW_SIX_BITS_U8: u8 = 0x3F; - let rem = input.len() % 3; - let start_of_rem = input.len() - rem; - - while input_index < start_of_rem { - let input_chunk = &input[input_index..(input_index + 3)]; - let output_chunk = &mut output[output_index..(output_index + 4)]; - - output_chunk[0] = encode_table[(input_chunk[0] >> 2) as usize]; - output_chunk[1] = - encode_table[((input_chunk[0] << 4 | input_chunk[1] >> 4) & LOW_SIX_BITS_U8) as usize]; - output_chunk[2] = - encode_table[((input_chunk[1] << 2 | input_chunk[2] >> 6) & LOW_SIX_BITS_U8) as usize]; - output_chunk[3] = encode_table[(input_chunk[2] & LOW_SIX_BITS_U8) as usize]; - - input_index += 3; - output_index += 4; - } - - if rem == 2 { - output[output_index] = encode_table[(input[start_of_rem] >> 2) as usize]; - output[output_index + 1] = encode_table[((input[start_of_rem] << 4 - | input[start_of_rem + 1] >> 4) - & LOW_SIX_BITS_U8) as usize]; - output[output_index + 2] = - encode_table[((input[start_of_rem + 1] << 2) & LOW_SIX_BITS_U8) as usize]; - output_index += 3; - } else if rem == 1 { - output[output_index] = encode_table[(input[start_of_rem] >> 2) as usize]; - output[output_index + 1] = - encode_table[((input[start_of_rem] << 4) & LOW_SIX_BITS_U8) as usize]; - output_index += 2; - } - - // add padding - let rem = input.len() % 3; - for _ in 0..((3 - rem) % 3) { - output[output_index] = b'='; - output_index += 1; - } - - output_index -} - -// ************************************************************************************************ -// **************************************** TESTS ************************************************* -// ************************************************************************************************ - -#[cfg(test)] -mod tests { - extern crate std; - - // ASCII values A-Za-z0-9+/ - pub const STANDARD_ENCODE: &'static [u8; 64] = &[ - 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, - 88, 89, 90, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, - 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, - 43, 47, - ]; - - #[test] - fn base64_encode_test() { - let input: &[u8] = &[0; 20]; - let output: &mut [u8] = &mut [0; 100]; - let encode_table: &[u8; 64] = STANDARD_ENCODE; - let mut input_index: usize = 0; - let mut output_index = 0; - - const LOW_SIX_BITS_U8: u8 = 0x3F; - - let rem = input.len() % 3; - let start_of_rem = input.len() - rem; - - while input_index < start_of_rem { - let input_chunk = &input[input_index..(input_index + 3)]; - let output_chunk = &mut output[output_index..(output_index + 4)]; - - output_chunk[0] = encode_table[(input_chunk[0] >> 2) as usize]; - output_chunk[1] = encode_table - [((input_chunk[0] << 4 | input_chunk[1] >> 4) & LOW_SIX_BITS_U8) as usize]; - output_chunk[2] = encode_table - [((input_chunk[1] << 2 | input_chunk[2] >> 6) & LOW_SIX_BITS_U8) as usize]; - output_chunk[3] = encode_table[(input_chunk[2] & LOW_SIX_BITS_U8) as usize]; - - input_index += 3; - output_index += 4; - } - } -} diff --git a/src/http.rs b/src/http.rs index d1cc46d..7d60836 100644 --- a/src/http.rs +++ b/src/http.rs @@ -5,7 +5,7 @@ use heapless::{String, Vec}; /// Websocket details extracted from the http header pub struct WebSocketContext { /// The list of sub protocols is restricted to a maximum of 3 - pub sec_websocket_protocol_list: Vec, + pub sec_websocket_protocol_list: Vec, /// The websocket key user to build the accept string to complete the opening handshake pub sec_websocket_key: WebSocketKey, } @@ -40,7 +40,7 @@ pub struct WebSocketContext { pub fn read_http_header<'a>( headers: impl Iterator, ) -> Result> { - let mut sec_websocket_protocol_list: Vec, U3> = Vec::new(); + let mut sec_websocket_protocol_list: Vec, 3> = Vec::new(); let mut is_websocket_request = false; let mut sec_websocket_key = String::new(); @@ -130,13 +130,13 @@ pub fn build_connect_handshake_request( rng: &mut impl RngCore, to: &mut [u8], ) -> Result<(usize, WebSocketKey)> { - let mut http_request: String = String::new(); + let mut http_request: String<1024> = String::new(); let mut key_as_base64: [u8; 24] = [0; 24]; let mut key: [u8; 16] = [0; 16]; rng.fill_bytes(&mut key); - base64::encode(&key, &mut key_as_base64); - let sec_websocket_key: String = String::from(str::from_utf8(&key_as_base64)?); + base64::encode_config_slice(&key, base64::STANDARD, &mut key_as_base64); + let sec_websocket_key: String<24> = String::from(str::from_utf8(&key_as_base64)?); http_request.push_str("GET ")?; http_request.push_str(websocket_options.path)?; @@ -177,7 +177,7 @@ pub fn build_connect_handshake_response( sec_websocket_protocol: Option<&WebSocketSubProtocol>, to: &mut [u8], ) -> Result { - let mut http_response: String = String::new(); + let mut http_response: String<1024> = String::new(); http_response.push_str( "HTTP/1.1 101 Switching Protocols\r\n\ Connection: Upgrade\r\nUpgrade: websocket\r\n", @@ -204,13 +204,14 @@ pub fn build_connect_handshake_response( pub fn build_accept_string(sec_websocket_key: &WebSocketKey, output: &mut [u8]) -> Result<()> { // concatenate the key with a known websocket GUID (as per the spec) - let mut accept_string: String = String::new(); + let mut accept_string: String<64> = String::new(); accept_string.push_str(sec_websocket_key)?; accept_string.push_str("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")?; // calculate the base64 encoded sha1 hash of the accept string above - let sha1 = Sha1::from(&accept_string); - let input = sha1.digest().bytes(); - base64::encode(&input, output); // no need for slices since the output WILL be 28 bytes + let mut sha1 = Sha1::new(); + sha1.update(&accept_string); + let input = sha1.finalize(); + base64::encode_config_slice(&input, base64::STANDARD, output); // no need for slices since the output WILL be 28 bytes Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index 583a7d6..37376b3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,12 +15,10 @@ use byteorder::{BigEndian, ByteOrder}; use core::{cmp, result, str}; -use heapless::consts::{U1024, U24, U256, U3, U64}; use heapless::{String, Vec}; use rand_core::RngCore; -use sha1::Sha1; +use sha1::{Sha1, Digest}; -mod base64; mod http; pub mod random; pub use self::http::{read_http_header, WebSocketContext}; @@ -36,10 +34,10 @@ const MASK_KEY_LEN: usize = 4; pub type Result = result::Result; /// A fixed length 24-character string used to hold a websocket key for the opening handshake -pub type WebSocketKey = String; +pub type WebSocketKey = String<24>; /// A maximum sized 24-character string used to store a sub protocol (e.g. `chat`) -pub type WebSocketSubProtocol = String; +pub type WebSocketSubProtocol = String<24>; /// Websocket send message type used when sending a websocket frame #[derive(PartialEq, Debug, Copy, Clone)] @@ -680,7 +678,7 @@ where if self.state == WebSocketState::Open { self.state = WebSocketState::CloseSent; if let Some(status_description) = status_description { - let mut from_buffer: Vec = Vec::new(); + let mut from_buffer: Vec = Vec::new(); BigEndian::write_u16(&mut from_buffer, close_status.to_u16()); // restrict the max size of the status_description @@ -690,7 +688,7 @@ where 254 }; - from_buffer.extend(status_description[..len].as_bytes()); + from_buffer.extend_from_slice(status_description[..len].as_bytes())?; self.write_frame(&from_buffer, to, WebSocketOpCode::ConnectionClose, true) } else { let mut from_buffer: [u8; 2] = [0; 2]; From c6169603dfca44bc009430d988dee7fce82f743b Mon Sep 17 00:00:00 2001 From: Steve Fan <29133953+stevefan1999-personal@users.noreply.github.com> Date: Fri, 24 Jun 2022 15:14:24 +0000 Subject: [PATCH 2/2] add simd support --- Cargo.toml | 2 ++ src/http.rs | 23 +++++++++++++++++++++-- src/lib.rs | 9 +++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ade4717..0150840 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,8 @@ byteorder = { version = "1.4.3", default-features = false } httparse = { version = "1.7.1", default-features = false } rand_core = "0.6.3" base64 = { version = "0.13.0", default-features = false } +base64-simd = { version = "0.5.0", default-features = false, optional = true } +cfg-if = "1.0.0" [dev-dependencies] rand = "0.8.5" diff --git a/src/http.rs b/src/http.rs index 7d60836..56778bf 100644 --- a/src/http.rs +++ b/src/http.rs @@ -135,7 +135,17 @@ pub fn build_connect_handshake_request( let mut key: [u8; 16] = [0; 16]; rng.fill_bytes(&mut key); - base64::encode_config_slice(&key, base64::STANDARD, &mut key_as_base64); + + cfg_if::cfg_if! { + if #[cfg(feature = "base64-simd")] { + use base64_simd::{Base64, OutBuf}; + Base64::STANDARD.encode(&key, OutBuf::from_slice_mut(&mut key_as_base64))?; + } else { + base64::encode_config_slice(&key, base64::STANDARD, &mut key_as_base64); + } + } + + let sec_websocket_key: String<24> = String::from(str::from_utf8(&key_as_base64)?); http_request.push_str("GET ")?; @@ -212,6 +222,15 @@ pub fn build_accept_string(sec_websocket_key: &WebSocketKey, output: &mut [u8]) let mut sha1 = Sha1::new(); sha1.update(&accept_string); let input = sha1.finalize(); - base64::encode_config_slice(&input, base64::STANDARD, output); // no need for slices since the output WILL be 28 bytes + + cfg_if::cfg_if! { + if #[cfg(feature = "base64-simd")] { + use base64_simd::{Base64, OutBuf}; + Base64::STANDARD.encode(&input, OutBuf::from_slice_mut(output))?; + } else { + base64::encode_config_slice(&input, base64::STANDARD, output); // no need for slices since the output WILL be 28 bytes + } + } + Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index 37376b3..56552cc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -211,6 +211,8 @@ pub enum Error { ConvertInfallible, RandCore, UnexpectedContinuationFrame, + #[cfg(feature = "base64-simd")] + Base64Error, } impl From for Error { @@ -237,6 +239,13 @@ impl From<()> for Error { } } +#[cfg(feature = "base64-simd")] +impl From for Error { + fn from(_: base64_simd::Error) -> Error { + Error::Base64Error + } +} + #[derive(Copy, Clone, Debug, PartialEq)] enum WebSocketOpCode { ContinuationFrame = 0,