Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 51 additions & 34 deletions src/encode/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ use crate::num::{SignedVarIntTarget, VarIntTarget};
/// assert_eq!(encoded, ([185, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 2));
/// ```
#[inline]
#[cfg(any(target_feature = "sse2", doc))]
#[cfg_attr(rustc_nightly, doc(cfg(target_feature = "sse2")))]
pub fn encode<T: VarIntTarget>(num: T) -> ([u8; 16], u8) {
    // SAFETY: per the doc comment on `encode_unsafe`, it has no unsafe
    // behavior for any input; the `unsafe` marker only reflects its internal
    // use of unsafe intrinsic calls.
    unsafe { encode_unsafe(num) }
}
Expand All @@ -35,8 +33,6 @@ pub fn encode<T: VarIntTarget>(num: T) -> ([u8; 16], u8) {
/// assert_eq!(encoded, ([39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 1));
/// ```
#[inline]
#[cfg(any(target_feature = "sse2", doc))]
#[cfg_attr(rustc_nightly, doc(cfg(target_feature = "sse2")))]
pub fn encode_zigzag<T: SignedVarIntTarget>(num: T) -> ([u8; 16], u8) {
    // Map the signed value to its unsigned zigzag representation first, then
    // encode as an ordinary varint.
    // SAFETY: per the doc comment on `encode_unsafe`, it has no unsafe
    // behavior for any input.
    unsafe { encode_unsafe(T::Unsigned::zigzag(num)) }
}
Expand All @@ -48,8 +44,6 @@ pub fn encode_zigzag<T: SignedVarIntTarget>(num: T) -> ([u8; 16], u8) {
///
/// **Panics:** if the slice is too small to contain the varint.
#[inline]
#[cfg(any(target_feature = "sse2", doc))]
#[cfg_attr(rustc_nightly, doc(cfg(target_feature = "sse2")))]
pub fn encode_to_slice<T: VarIntTarget>(num: T, slice: &mut [u8]) -> u8 {
let (data, size) = encode(num);
slice[..size as usize].copy_from_slice(&data[..size as usize]);
Expand All @@ -66,8 +60,6 @@ pub fn encode_to_slice<T: VarIntTarget>(num: T, slice: &mut [u8]) -> u8 {
/// This should not have any unsafe behavior with any input. However, it still calls a large number
/// of unsafe functions.
#[inline]
#[cfg(any(target_feature = "sse2", doc))]
#[cfg_attr(rustc_nightly, doc(cfg(target_feature = "sse2")))]
pub unsafe fn encode_unsafe<T: VarIntTarget>(num: T) -> ([u8; 16], u8) {
if T::MAX_VARINT_BYTES <= 5 {
// We could kick off a lzcnt here on the original number but that makes the math complicated and slow
Expand All @@ -91,31 +83,56 @@ pub unsafe fn encode_unsafe<T: VarIntTarget>(num: T) -> ([u8; 16], u8) {
bytes_needed as u8,
)
} else {
// Break the number into 7-bit parts and spread them out into a vector
let stage1: __m128i = core::mem::transmute(num.num_to_vector_stage1());

// Create a mask for where there exist values
// This signed comparison works because all MSBs should be cleared at this point
// Also handle the special case when num == 0
let minimum = _mm_set_epi8(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xffu8 as i8);
let exists = _mm_or_si128(_mm_cmpgt_epi8(stage1, _mm_setzero_si128()), minimum);
let bits = _mm_movemask_epi8(exists);

// Count the number of bytes used
let bytes = 32 - bits.leading_zeros() as u8; // lzcnt on supported CPUs
// TODO: Compiler emits an unnecessary branch here when using bsr/bsl fallback

// Fill that many bytes into a vector
let ascend = _mm_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
let mask = _mm_cmplt_epi8(ascend, _mm_set1_epi8(bytes as i8));

// Shift it down 1 byte so the last MSB is the only one set, and make sure only the MSB is set
let shift = _mm_bsrli_si128(mask, 1);
let msbmask = _mm_and_si128(shift, _mm_set1_epi8(128u8 as i8));

// Merge the MSB bits into the vector
let merged = _mm_or_si128(stage1, msbmask);

(core::mem::transmute::<__m128i, [u8; 16]>(merged), bytes)
#[cfg(any(target_feature = "sse2", doc))]
#[cfg_attr(rustc_nightly, doc(cfg(target_feature = "sse2")))]
{
// Break the number into 7-bit parts and spread them out into a vector
let stage1: __m128i = core::mem::transmute(num.num_to_vector_stage1());

// Create a mask for where there exist values
// This signed comparison works because all MSBs should be cleared at this point
// Also handle the special case when num == 0
let minimum = _mm_set_epi8(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xffu8 as i8);
let exists = _mm_or_si128(_mm_cmpgt_epi8(stage1, _mm_setzero_si128()), minimum);
let bits = _mm_movemask_epi8(exists);

// Count the number of bytes used
let bytes = 32 - bits.leading_zeros() as u8; // lzcnt on supported CPUs
// TODO: Compiler emits an unnecessary branch here when using bsr/bsl fallback

// Fill that many bytes into a vector
let ascend = _mm_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
let mask = _mm_cmplt_epi8(ascend, _mm_set1_epi8(bytes as i8));

// Shift it down 1 byte so the last MSB is the only one set, and make sure only the MSB is set
let shift = _mm_bsrli_si128(mask, 1);
let msbmask = _mm_and_si128(shift, _mm_set1_epi8(128u8 as i8));

// Merge the MSB bits into the vector
let merged = _mm_or_si128(stage1, msbmask);

(core::mem::transmute::<__m128i, [u8; 16]>(merged), bytes)
}
#[cfg(not(target_feature = "sse2"))]
{
let stage1 = num.num_to_big_scalar_stage1();

// We could OR the data with 1 to avoid undefined behavior, but for some reason it's still faster to take the branch
let leading = stage1.leading_zeros();

let unused_bytes = (leading - 1) / 8;
let bytes_needed = 16 - unused_bytes;

// set all but the last MSBs
let msbs = 0x80808080808080808080808080808080;
let msbmask = u128::MAX >> ((16 - bytes_needed + 1) * 8 - 1);

let merged = stage1 | (msbs & msbmask);

(
core::mem::transmute::<[u128; 1], [u8; 16]>([merged]),
bytes_needed as u8,
)
}
}
}
23 changes: 23 additions & 0 deletions src/num.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ pub trait VarIntTarget: Debug + Eq + PartialEq + PartialOrd + Sized + Copy {
/// Splits this number into 7-bit segments for encoding
fn num_to_scalar_stage1(self) -> u64;

/// Same as `num_to_scalar_stage1`, but returns a `u128` instead of a `u64`.
///
/// Types should override this when `num_to_scalar_stage1` cannot be used
/// because the type is too big for its spread-out 7-bit representation to
/// fit in a `u64`.
fn num_to_big_scalar_stage1(self) -> u128 {
    self.num_to_scalar_stage1() as u128
}

/// Splits this number into 7-bit segments for encoding
fn num_to_vector_stage1(self) -> [u8; 16];

Expand Down Expand Up @@ -460,6 +466,23 @@ impl VarIntTarget for u64 {
unsafe { core::mem::transmute(res) }
}

#[inline(always)]
#[rustfmt::skip]
fn num_to_big_scalar_stage1(self) -> u128 {
    // Widen to u128: spreading 64 bits into 7-bit-per-byte groups needs up
    // to 10 output bytes, which no longer fits in a u64.
    let x = self as u128;

    // Split the value into consecutive 7-bit groups (low to high) and shift
    // group i left by i extra bits, so group i occupies the low 7 bits of
    // output byte i. Bit 7 of every output byte is left clear for the varint
    // continuation MSBs, which are OR-ed in by the caller. The masks are
    // kept as hand-aligned 64-bit literals (hence `#[rustfmt::skip]`) so
    // each group boundary is visible; the final term is the lone top bit 63,
    // which becomes the 10th encoded byte.
    (x & 0b0000000000000000000000000000000000000000000000000000000001111111)
    | ((x & 0b0000000000000000000000000000000000000000000000000011111110000000) << 1)
    | ((x & 0b0000000000000000000000000000000000000000000111111100000000000000) << 2)
    | ((x & 0b0000000000000000000000000000000000001111111000000000000000000000) << 3)
    | ((x & 0b0000000000000000000000000000011111110000000000000000000000000000) << 4)
    | ((x & 0b0000000000000000000000111111100000000000000000000000000000000000) << 5)
    | ((x & 0b0000000000000001111111000000000000000000000000000000000000000000) << 6)
    | ((x & 0b0000000011111110000000000000000000000000000000000000000000000000) << 7)
    | ((x & 0b0111111100000000000000000000000000000000000000000000000000000000) << 8)
    | ((x & 0b1000000000000000000000000000000000000000000000000000000000000000) << 9)
}

#[inline(always)]
fn cast_u32(num: u32) -> Self {
num as u64
Expand Down