diff --git a/src/encode/mod.rs b/src/encode/mod.rs index ceae1ce..0cd6eb0 100644 --- a/src/encode/mod.rs +++ b/src/encode/mod.rs @@ -18,8 +18,6 @@ use crate::num::{SignedVarIntTarget, VarIntTarget}; /// assert_eq!(encoded, ([185, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 2)); /// ``` #[inline] -#[cfg(any(target_feature = "sse2", doc))] -#[cfg_attr(rustc_nightly, doc(cfg(target_feature = "sse2")))] pub fn encode(num: T) -> ([u8; 16], u8) { unsafe { encode_unsafe(num) } } @@ -35,8 +33,6 @@ pub fn encode(num: T) -> ([u8; 16], u8) { /// assert_eq!(encoded, ([39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 1)); /// ``` #[inline] -#[cfg(any(target_feature = "sse2", doc))] -#[cfg_attr(rustc_nightly, doc(cfg(target_feature = "sse2")))] pub fn encode_zigzag(num: T) -> ([u8; 16], u8) { unsafe { encode_unsafe(T::Unsigned::zigzag(num)) } } @@ -48,8 +44,6 @@ pub fn encode_zigzag(num: T) -> ([u8; 16], u8) { /// /// **Panics:** if the slice is too small to contain the varint. #[inline] -#[cfg(any(target_feature = "sse2", doc))] -#[cfg_attr(rustc_nightly, doc(cfg(target_feature = "sse2")))] pub fn encode_to_slice(num: T, slice: &mut [u8]) -> u8 { let (data, size) = encode(num); slice[..size as usize].copy_from_slice(&data[..size as usize]); @@ -66,8 +60,6 @@ pub fn encode_to_slice(num: T, slice: &mut [u8]) -> u8 { /// This should not have any unsafe behavior with any input. However, it still calls a large number /// of unsafe functions. 
#[inline] -#[cfg(any(target_feature = "sse2", doc))] -#[cfg_attr(rustc_nightly, doc(cfg(target_feature = "sse2")))] pub unsafe fn encode_unsafe(num: T) -> ([u8; 16], u8) { if T::MAX_VARINT_BYTES <= 5 { // We could kick off a lzcnt here on the original number but that makes the math complicated and slow @@ -91,31 +83,56 @@ pub unsafe fn encode_unsafe(num: T) -> ([u8; 16], u8) { bytes_needed as u8, ) } else { - // Break the number into 7-bit parts and spread them out into a vector - let stage1: __m128i = core::mem::transmute(num.num_to_vector_stage1()); - - // Create a mask for where there exist values - // This signed comparison works because all MSBs should be cleared at this point - // Also handle the special case when num == 0 - let minimum = _mm_set_epi8(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xffu8 as i8); - let exists = _mm_or_si128(_mm_cmpgt_epi8(stage1, _mm_setzero_si128()), minimum); - let bits = _mm_movemask_epi8(exists); - - // Count the number of bytes used - let bytes = 32 - bits.leading_zeros() as u8; // lzcnt on supported CPUs - // TODO: Compiler emits an unnecessary branch here when using bsr/bsl fallback - - // Fill that many bytes into a vector - let ascend = _mm_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - let mask = _mm_cmplt_epi8(ascend, _mm_set1_epi8(bytes as i8)); - - // Shift it down 1 byte so the last MSB is the only one set, and make sure only the MSB is set - let shift = _mm_bsrli_si128(mask, 1); - let msbmask = _mm_and_si128(shift, _mm_set1_epi8(128u8 as i8)); - - // Merge the MSB bits into the vector - let merged = _mm_or_si128(stage1, msbmask); - - (core::mem::transmute::<__m128i, [u8; 16]>(merged), bytes) + #[cfg(any(target_feature = "sse2", doc))] + #[cfg_attr(rustc_nightly, doc(cfg(target_feature = "sse2")))] + { + // Break the number into 7-bit parts and spread them out into a vector + let stage1: __m128i = core::mem::transmute(num.num_to_vector_stage1()); + + // Create a mask for where there exist values 
+ // This signed comparison works because all MSBs should be cleared at this point + // Also handle the special case when num == 0 + let minimum = _mm_set_epi8(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xffu8 as i8); + let exists = _mm_or_si128(_mm_cmpgt_epi8(stage1, _mm_setzero_si128()), minimum); + let bits = _mm_movemask_epi8(exists); + + // Count the number of bytes used + let bytes = 32 - bits.leading_zeros() as u8; // lzcnt on supported CPUs + // TODO: Compiler emits an unnecessary branch here when using bsr/bsl fallback + + // Fill that many bytes into a vector + let ascend = _mm_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + let mask = _mm_cmplt_epi8(ascend, _mm_set1_epi8(bytes as i8)); + + // Shift it down 1 byte so the last MSB is the only one set, and make sure only the MSB is set + let shift = _mm_bsrli_si128(mask, 1); + let msbmask = _mm_and_si128(shift, _mm_set1_epi8(128u8 as i8)); + + // Merge the MSB bits into the vector + let merged = _mm_or_si128(stage1, msbmask); + + (core::mem::transmute::<__m128i, [u8; 16]>(merged), bytes) + } + #[cfg(not(target_feature = "sse2"))] + { + let stage1 = num.num_to_big_scalar_stage1(); + + // We could OR the data with 1 to avoid undefined behavior, but for some reason it's still faster to take the branch + let leading = stage1.leading_zeros(); + + let unused_bytes = (leading - 1) / 8; + let bytes_needed = 16 - unused_bytes; + + // set all but the last MSBs + let msbs = 0x80808080808080808080808080808080; + let msbmask = u128::MAX >> ((16 - bytes_needed + 1) * 8 - 1); + + let merged = stage1 | (msbs & msbmask); + + ( + core::mem::transmute::<[u128; 1], [u8; 16]>([merged]), + bytes_needed as u8, + ) + } } } diff --git a/src/num.rs b/src/num.rs index bb848e2..34ef7f5 100644 --- a/src/num.rs +++ b/src/num.rs @@ -35,6 +35,12 @@ pub trait VarIntTarget: Debug + Eq + PartialEq + PartialOrd + Sized + Copy { /// Splits this number into 7-bit segments for encoding fn num_to_scalar_stage1(self) -> u64; + 
/// Same as `num_to_scalar_stage1`, but returns a u128 instead of a u64 + /// This should be implemented when `num_to_scalar_stage1` cannot be called because the type is too big + fn num_to_big_scalar_stage1(self) -> u128 { + self.num_to_scalar_stage1() as u128 + } + /// Splits this number into 7-bit segments for encoding fn num_to_vector_stage1(self) -> [u8; 16]; @@ -460,6 +466,23 @@ impl VarIntTarget for u64 { unsafe { core::mem::transmute(res) } } + #[inline(always)] + #[rustfmt::skip] + fn num_to_big_scalar_stage1(self) -> u128 { + let x = self as u128; + + (x & 0b0000000000000000000000000000000000000000000000000000000001111111) + | ((x & 0b0000000000000000000000000000000000000000000000000011111110000000) << 1) + | ((x & 0b0000000000000000000000000000000000000000000111111100000000000000) << 2) + | ((x & 0b0000000000000000000000000000000000001111111000000000000000000000) << 3) + | ((x & 0b0000000000000000000000000000011111110000000000000000000000000000) << 4) + | ((x & 0b0000000000000000000000111111100000000000000000000000000000000000) << 5) + | ((x & 0b0000000000000001111111000000000000000000000000000000000000000000) << 6) + | ((x & 0b0000000011111110000000000000000000000000000000000000000000000000) << 7) + | ((x & 0b0111111100000000000000000000000000000000000000000000000000000000) << 8) + | ((x & 0b1000000000000000000000000000000000000000000000000000000000000000) << 9) + } + #[inline(always)] fn cast_u32(num: u32) -> Self { num as u64