diff --git a/Cargo.toml b/Cargo.toml index 00a51df..4640209 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,16 +9,19 @@ default = ["concurrent_stat"] concurrent_stat = [] [dependencies] +arr_macro = "0.1.3" crossbeam-epoch = "0.9.5" crossbeam-utils = "0.8.5" +either = "1.6.1" +parking_lot = "0.12.1" rand = "0.8.4" thread_local = "1.1.4" -parking_lot = "0.12.1" +uninit = "0.5.1" [dev-dependencies] criterion = "0.3.4" -num_cpus = "1.13.0" crossbeam-queue = "0.3.5" +num_cpus = "1.13.0" [[bench]] name = "stack" diff --git a/benches/btree.rs b/benches/btree.rs index 0a4fa29..01a1769 100644 --- a/benches/btree.rs +++ b/benches/btree.rs @@ -22,7 +22,7 @@ const OPS_RATE: [(usize, usize, usize); 7] = [ (50, 0, 50), ]; -fn bench_vs_btreemap(c: &mut Criterion) { +fn bench_u64_vs_btreemap(c: &mut Criterion) { for (insert, lookup, remove) in OPS_RATE { let logs = fuzz_sequential_logs( 200, @@ -42,12 +42,10 @@ fn bench_vs_btreemap(c: &mut Criterion) { group.throughput(Throughput::Elements(MAP_TOTAL_OPS as u64)); bench_logs_btreemap(logs.clone(), &mut group); - bench_logs_sequential_map::>("BTree", logs.clone(), &mut group); - bench_logs_sequential_map::>("AVLTree", logs, &mut group); + bench_logs_sequential_map::>("BTree", logs.clone(), &mut group); + bench_logs_sequential_map::>("AVLTree", logs, &mut group); } } -criterion_group!(bench, bench_vs_btreemap); -criterion_main! { - bench, -} +criterion_group!(bench, bench_u64_vs_btreemap); +criterion_main! 
{bench} diff --git a/benches/util/sequential.rs b/benches/util/sequential.rs index f59b34d..ce2feb3 100644 --- a/benches/util/sequential.rs +++ b/benches/util/sequential.rs @@ -3,43 +3,60 @@ use std::{ time::{Duration, Instant}, }; -use cds::{map::SequentialMap, queue::SequentialQueue}; +use cds::{map::SequentialMap, queue::SequentialQueue, util::random::Random}; use criterion::{black_box, measurement::WallTime, BenchmarkGroup}; use rand::{prelude::SliceRandom, thread_rng, Rng}; #[derive(Clone, Copy)] -pub enum Op { - Insert(u64), - Lookup(u64), - Remove(u64), +pub enum Op { + Insert(K, V), + Lookup(K), + Remove(K), } -pub fn fuzz_sequential_logs( +type Logs = Vec<(Vec<(K, V)>, Vec>)>; + +pub fn fuzz_sequential_logs( iters: u64, already_inserted: u64, insert: usize, lookup: usize, remove: usize, -) -> Vec<(Vec, Vec)> { +) -> Logs { let mut rng = thread_rng(); let mut result = Vec::new(); for _ in 0..iters { let mut logs = Vec::new(); - let mut pre_inserted: Vec = (0..already_inserted).collect(); - pre_inserted.shuffle(&mut rng); + let mut pre_inserted = Vec::new(); + + for _ in 0..already_inserted { + pre_inserted.push((K::gen(&mut rng), V::gen(&mut rng))); + } for _ in 0..insert { - logs.push(Op::Insert(rng.gen_range(already_inserted..u64::MAX))); + logs.push(Op::Insert(K::gen(&mut rng), V::gen(&mut rng))); } - for _ in 0..lookup { - logs.push(Op::Lookup(rng.gen_range(0..already_inserted))); + for i in 0..lookup { + if i % 2 == 0 { + logs.push(Op::Lookup(K::gen(&mut rng))); + } else { + logs.push(Op::Lookup( + pre_inserted.choose(&mut rng).cloned().unwrap().0, + )); + } } - for _ in 0..remove { - logs.push(Op::Remove(rng.gen_range(0..already_inserted))); + for i in 0..remove { + if i % 2 == 0 { + logs.push(Op::Remove(K::gen(&mut rng))); + } else { + logs.push(Op::Remove( + pre_inserted.choose(&mut rng).cloned().unwrap().0, + )); + } } logs.shuffle(&mut rng); @@ -84,7 +101,7 @@ where }); } -pub fn bench_logs_btreemap(mut logs: Vec<(Vec, Vec)>, c: &mut 
BenchmarkGroup) { +pub fn bench_logs_btreemap(mut logs: Logs, c: &mut BenchmarkGroup) { c.bench_function("std::BTreeMap", |b| { b.iter_custom(|iters| { let mut duration = Duration::ZERO; @@ -94,15 +111,15 @@ pub fn bench_logs_btreemap(mut logs: Vec<(Vec, Vec)>, c: &mut Benchmark let mut map = BTreeMap::new(); // pre-insert - for key in pre_inserted { - let _ = map.insert(key, key); + for (key, value) in pre_inserted { + let _ = map.insert(key, value); } let start = Instant::now(); for op in logs { match op { - Op::Insert(key) => { - let _ = black_box(map.insert(key, key)); + Op::Insert(key, value) => { + let _ = black_box(map.insert(key, value)); } Op::Lookup(key) => { let _ = black_box(map.get(&key)); @@ -120,12 +137,13 @@ pub fn bench_logs_btreemap(mut logs: Vec<(Vec, Vec)>, c: &mut Benchmark }); } -pub fn bench_logs_sequential_map( +pub fn bench_logs_sequential_map( name: &str, - mut logs: Vec<(Vec, Vec)>, + mut logs: Logs, c: &mut BenchmarkGroup, ) where - M: SequentialMap, + K: Eq + Random, + M: SequentialMap, { c.bench_function(name, |b| { b.iter_custom(|iters| { @@ -136,15 +154,15 @@ pub fn bench_logs_sequential_map( let mut map = M::new(); // pre-insert - for key in pre_inserted { - let _ = map.insert(&key, key); + for (key, value) in pre_inserted { + let _ = map.insert(&key, value); } let start = Instant::now(); for op in logs { match op { - Op::Insert(key) => { - let _ = black_box(map.insert(&key, key)); + Op::Insert(key, value) => { + let _ = black_box(map.insert(&key, value)); } Op::Lookup(key) => { let _ = black_box(map.lookup(&key)); diff --git a/src/art/mod.rs b/src/art/mod.rs new file mode 100644 index 0000000..2cf9163 --- /dev/null +++ b/src/art/mod.rs @@ -0,0 +1,1430 @@ +mod utf8art; + +use std::{ + cmp::{min, Ordering}, + marker::PhantomData, + mem, + ptr::{self, NonNull}, +}; + +use arr_macro::arr; +use either::Either; +use std::fmt::Debug; + +use crate::{ + left_or, + map::SequentialMap, + util::{slice_insert, slice_remove}, +}; + +const 
PREFIX_LEN: usize = 12; +const KEY_ENDMARK: u8 = 0xff; // invalid on utf-8. Thus, use it for preventing that any key cannot be the prefix of another key. +struct NodeHeader { + len: u32, // the len of prefix + prefix: [u8; PREFIX_LEN], // prefix for path compression +} + +impl Debug for NodeHeader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + unsafe { + f.debug_struct("NodeHeader") + .field("len", &self.len) + .field( + "prefix", + &self + .prefix + .get_unchecked(..min(PREFIX_LEN, self.len as usize)), + ) + .finish() + } + } +} + +impl Default for NodeHeader { + #[allow(deprecated)] + fn default() -> Self { + unsafe { + Self { + len: 0, + prefix: mem::uninitialized(), + } + } + } +} + +/// the child node type +/// This is used for bitflag on child pointer. +const NODETYPE_MASK: usize = 0b111; +#[repr(usize)] +#[derive(Debug, PartialEq)] +enum NodeType { + Value = 0b000, + Node4 = 0b001, + Node16 = 0b010, + Node48 = 0b011, + Node256 = 0b100, +} + +trait NodeOps { + fn header(&self) -> &NodeHeader; + fn header_mut(&mut self) -> &mut NodeHeader; + fn size(&self) -> usize; + fn is_full(&self) -> bool; + fn is_shrinkable(&self) -> bool; + fn get_any_child(&self) -> Option<&NodeV>; + fn insert(&mut self, key: u8, node: Node) -> Result<(), Node>; + fn lookup(&self, key: u8) -> Option<&Node>; + fn lookup_mut(&mut self, key: u8) -> Option<&mut Node>; + fn update(&mut self, key: u8, node: Node) -> Result, Node>; + fn remove(&mut self, key: u8) -> Result, ()>; +} + +/// the pointer struct for Nodes or value +struct Node { + pointer: usize, + _marker: PhantomData>, +} + +impl Debug for Node { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + unsafe { + let pointer = self.pointer & !NODETYPE_MASK; + let tag = mem::transmute(self.pointer & NODETYPE_MASK); + + match tag { + NodeType::Value => (&*(pointer as *const NodeV)).fmt(f), + NodeType::Node4 => (&*(pointer as *const Node4)).fmt(f), + NodeType::Node16 => (&*(pointer as 
*const Node16)).fmt(f), + NodeType::Node48 => (&*(pointer as *const Node48)).fmt(f), + NodeType::Node256 => (&*(pointer as *const Node256)).fmt(f), + } + } + } +} + +impl Node { + fn deref(&self) -> Either<&dyn NodeOps, &NodeV> { + unsafe { + let pointer = self.pointer & !NODETYPE_MASK; + let tag = mem::transmute(self.pointer & NODETYPE_MASK); + + match tag { + NodeType::Value => Either::Right(&*(pointer as *const NodeV)), + NodeType::Node4 => Either::Left(&*(pointer as *const Node4)), + NodeType::Node16 => Either::Left(&*(pointer as *const Node16)), + NodeType::Node48 => Either::Left(&*(pointer as *const Node48)), + NodeType::Node256 => Either::Left(&*(pointer as *const Node256)), + } + } + } + + fn deref_mut(&self) -> Either<&mut dyn NodeOps, &mut NodeV> { + unsafe { + let pointer = self.pointer & !NODETYPE_MASK; + let tag = mem::transmute(self.pointer & NODETYPE_MASK); + + match tag { + NodeType::Value => Either::Right(&mut *(pointer as *mut NodeV)), + NodeType::Node4 => Either::Left(&mut *(pointer as *mut Node4)), + NodeType::Node16 => Either::Left(&mut *(pointer as *mut Node16)), + NodeType::Node48 => Either::Left(&mut *(pointer as *mut Node48)), + NodeType::Node256 => Either::Left(&mut *(pointer as *mut Node256)), + } + } + } + + fn inner(self) -> Box { + // TODO: how to improve this function safely(self.node_type() == T::node_type()) + unsafe { + let pointer = self.pointer & !NODETYPE_MASK; + // let tag = mem::transmute(self.pointer & NODETYPE_MASK); + + Box::from_raw(pointer as *mut T) + } + } + + fn new(node: T, node_type: NodeType) -> Self { + let node = Box::into_raw(Box::new(node)); + + Self { + pointer: node as usize | node_type as usize, + _marker: PhantomData, + } + } + + const fn null() -> Self { + Self { + pointer: 0, + _marker: PhantomData, + } + } + + #[inline] + fn is_null(&self) -> bool { + self.pointer == 0 + } + + fn node_type(&self) -> NodeType { + unsafe { mem::transmute(self.pointer & NODETYPE_MASK) } + } + + /// extend node to bigger one 
only if necessary + fn extend(&mut self) { + if self.deref().is_right() { + return; + } + + if !self.deref().left().unwrap().is_full() { + return; + } + + let node_type = self.node_type(); + let node = self.deref_mut().left().unwrap(); + + match node_type { + NodeType::Value => unreachable!(), + NodeType::Node4 => unsafe { + let node = node as *const dyn NodeOps as *const Node4; + let new = Box::new(Node16::from(ptr::read(node))); + self.pointer = Box::into_raw(new) as usize | NodeType::Node16 as usize; + }, + NodeType::Node16 => unsafe { + let node = node as *const dyn NodeOps as *const Node16; + let new = Box::new(Node48::from(ptr::read(node))); + self.pointer = Box::into_raw(new) as usize | NodeType::Node48 as usize; + }, + NodeType::Node48 => unsafe { + let node = node as *const dyn NodeOps as *const Node48; + let new = Box::new(Node256::from(ptr::read(node))); + self.pointer = Box::into_raw(new) as usize | NodeType::Node256 as usize; + }, + NodeType::Node256 => {} + } + } + + /// shrink node to smaller one only if necessary + fn shrink(&mut self) { + if self.deref().is_right() { + return; + } + + if !self.deref().left().unwrap().is_shrinkable() { + return; + } + + let node_type = self.node_type(); + let node = self.deref_mut().left().unwrap(); + + match node_type { + NodeType::Value => unreachable!(), + NodeType::Node4 => {} + NodeType::Node16 => unsafe { + let node = node as *const dyn NodeOps as *const Node16; + let new = Box::new(Node4::from(ptr::read(node))); + self.pointer = Box::into_raw(new) as usize | NodeType::Node4 as usize; + }, + NodeType::Node48 => unsafe { + let node = node as *const dyn NodeOps as *const Node48; + let new = Box::new(Node16::from(ptr::read(node))); + self.pointer = Box::into_raw(new) as usize | NodeType::Node16 as usize; + }, + NodeType::Node256 => unsafe { + let node = node as *const dyn NodeOps as *const Node256; + let new = Box::new(Node48::from(ptr::read(node))); + self.pointer = Box::into_raw(new) as usize | NodeType::Node48 
as usize; + }, + } + } + + /// compress path if the node is Node4 with having one child + /// If self's unique one child is not NodeV(internal node), then compress path from self.header to + /// self.header + self.key(of child) + child.header and set child on self. + /// If self's one is NodeV(external node), just set child on self.(not need to compress path on header). + fn compress_path(&mut self) { + if self.node_type() != NodeType::Node4 { + return; + } + + if self.deref().left().unwrap().size() != 1 { + return; + } + + unsafe { + let node = Box::from_raw((self.pointer & !NODETYPE_MASK) as *mut Node4); + + let child_key = *node.keys.get_unchecked(0); + let child = ptr::read(node.children.get_unchecked(0)); + + // if the child is not NodeV, then move prefix from parent to child + if let Either::Left(child) = child.deref_mut() { + // push child key on front of child header prefix + let prefix_ptr = child.header_mut().prefix.as_mut_ptr(); + let prefix_len = child.header().len as usize; + + ptr::copy( + prefix_ptr, + prefix_ptr.add(1), + min(prefix_len, PREFIX_LEN - 1), + ); + *prefix_ptr = child_key; + + child.header_mut().len += 1; + + if node.header.len > 0 { + let node_prefix_len = node.header.len as usize; + let prefix_len = child.header().len as usize; + + if PREFIX_LEN > node_prefix_len { + ptr::copy( + prefix_ptr, + prefix_ptr.add(node_prefix_len as usize), + min(prefix_len, PREFIX_LEN - node_prefix_len), + ); + } + + ptr::copy_nonoverlapping( + node.header.prefix.as_ptr(), + prefix_ptr, + min(node_prefix_len, PREFIX_LEN), + ); + + child.header_mut().len = (prefix_len + node_prefix_len) as u32; + } + } + + mem::forget(node); + *self = child; + } + } + + /// compare the keys from depth to header.len + fn prefix_match(keys: &[u8], node: &dyn NodeOps, depth: usize) -> Result<(), usize> { + let header = node.header(); + + for (index, prefix) in unsafe { + header + .prefix + .get_unchecked(..min(PREFIX_LEN, header.len as usize)) + .iter() + .enumerate() + } { + 
if keys[depth + index] != *prefix { + return Err(depth + index); + } + } + + if header.len > PREFIX_LEN as u32 { + // check strictly by using leaf node + let any_child = node.get_any_child().unwrap(); + + let mut d = depth + PREFIX_LEN; + + while d < depth + header.len as usize { + if keys[d] != any_child.key[d] { + return Err(d); + } + + d += 1; + } + } + + Ok(()) + } +} + +struct NodeV { + key: Box<[u8]>, + value: V, +} + +impl Debug for NodeV { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NodeV") + .field("key", &self.key) + .field("value", &self.value) + .finish() + } +} + +impl NodeV { + fn new(key: Vec, value: V) -> Self { + Self { + key: key.into(), + value, + } + } +} + +struct Node4 { + header: NodeHeader, + len: usize, + keys: [u8; 4], + children: [Node; 4], +} + +impl Debug for Node4 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Node4") + .field("header", &self.header) + .field("len", &self.len) + .field("keys", &self.keys()) + .field("children", &self.children()) + .finish() + } +} + +impl Default for Node4 { + #[allow(deprecated)] + fn default() -> Self { + unsafe { + Self { + header: Default::default(), + len: 0, + keys: mem::uninitialized(), + children: mem::uninitialized(), + } + } + } +} + +impl From> for Node4 { + fn from(node: Node16) -> Self { + debug_assert!(node.len <= 4); + + let mut new = Self::default(); + new.header = node.header; + new.len = node.len; + + unsafe { + ptr::copy_nonoverlapping(node.keys.as_ptr(), new.keys.as_mut_ptr(), node.len as usize); + ptr::copy_nonoverlapping( + node.children.as_ptr(), + new.children.as_mut_ptr(), + node.len as usize, + ); + } + + new + } +} + +impl Node4 { + fn keys(&self) -> &[u8] { + unsafe { self.keys.get_unchecked(..self.len as usize) } + } + + fn mut_keys(&mut self) -> &mut [u8] { + unsafe { self.keys.get_unchecked_mut(..self.len as usize) } + } + + fn children(&self) -> &[Node] { + unsafe { 
self.children.get_unchecked(..self.len as usize) } + } + + fn mut_children(&mut self) -> &mut [Node] { + unsafe { self.children.get_unchecked_mut(..self.len as usize) } + } +} + +impl NodeOps for Node4 { + fn header(&self) -> &NodeHeader { + &self.header + } + + fn header_mut(&mut self) -> &mut NodeHeader { + &mut self.header + } + + fn size(&self) -> usize { + self.len + } + + fn is_full(&self) -> bool { + self.len == 4 + } + + fn is_shrinkable(&self) -> bool { + false + } + + fn get_any_child(&self) -> Option<&NodeV> { + debug_assert!(self.size() > 0); + + match unsafe { self.children.get_unchecked(0).deref() } { + Either::Left(node) => node.get_any_child(), + Either::Right(nodev) => return Some(nodev), + } + } + + fn insert(&mut self, key: u8, node: Node) -> Result<(), Node> { + debug_assert!(!self.is_full()); + + for (index, k) in self.keys().iter().enumerate() { + match key.cmp(k) { + Ordering::Less => unsafe { + self.len += 1; + slice_insert(self.mut_keys(), index, key); + slice_insert(self.mut_children(), index, node); + return Ok(()); + }, + Ordering::Equal => return Err(node), + Ordering::Greater => {} + } + } + + let index = self.len; + unsafe { + self.len += 1; + slice_insert(self.mut_keys(), index, key); + slice_insert(self.mut_children(), index, node); + } + + Ok(()) + } + + fn lookup(&self, key: u8) -> Option<&Node> { + for (index, k) in self.keys().iter().enumerate() { + if key == *k { + return unsafe { Some(self.children.get_unchecked(index)) }; + } + } + + None + } + + fn lookup_mut(&mut self, key: u8) -> Option<&mut Node> { + for (index, k) in self.keys().iter().enumerate() { + if key == *k { + return unsafe { Some(self.children.get_unchecked_mut(index)) }; + } + } + + None + } + + fn update(&mut self, key: u8, node: Node) -> Result, Node> { + for (index, k) in self.keys().iter().enumerate() { + match key.cmp(k) { + Ordering::Less => {} + Ordering::Equal => unsafe { + let node = mem::replace(self.children.get_unchecked_mut(index), node); + return 
Ok(node); + }, + Ordering::Greater => {} + } + } + + Err(node) + } + + fn remove(&mut self, key: u8) -> Result, ()> { + debug_assert!(self.len != 0); + + for (index, k) in self.keys().iter().enumerate() { + match key.cmp(k) { + Ordering::Less => {} + Ordering::Equal => unsafe { + let _ = slice_remove(self.mut_keys(), index); + let node = slice_remove(self.mut_children(), index); + self.len -= 1; + return Ok(node); + }, + Ordering::Greater => {} + } + } + + Err(()) + } +} + +struct Node16 { + header: NodeHeader, + len: usize, + keys: [u8; 16], + children: [Node; 16], +} + +impl Debug for Node16 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Node16") + .field("header", &self.header) + .field("len", &self.len) + .field("keys", &self.keys()) + .field("children", &self.children()) + .finish() + } +} + +impl Default for Node16 { + #[allow(deprecated)] + fn default() -> Self { + unsafe { + Self { + header: Default::default(), + len: 0, + keys: mem::uninitialized(), + children: mem::uninitialized(), + } + } + } +} + +impl From> for Node16 { + fn from(node: Node4) -> Self { + debug_assert!(node.len == 4); + + let mut new = Self::default(); + new.header = node.header; + new.len = node.len; + + unsafe { + ptr::copy_nonoverlapping(node.keys.as_ptr(), new.keys.as_mut_ptr(), node.len as usize); + ptr::copy_nonoverlapping( + node.children.as_ptr(), + new.children.as_mut_ptr(), + node.len as usize, + ); + } + + new + } +} + +impl From> for Node16 { + fn from(node: Node48) -> Self { + debug_assert!(node.len <= 16); + + let mut new = Self::default(); + new.header = node.header; + new.len = node.len; + + unsafe { + let mut i = 0; + for (key, index) in node.keys.iter().enumerate() { + if *index != 0xff { + *new.keys.get_unchecked_mut(i) = key as u8; + *new.children.get_unchecked_mut(i) = + ptr::read(node.children.get_unchecked(*index as usize)); + i += 1; + } + } + + debug_assert!(i <= 16); + } + + new + } +} + +impl Node16 { + fn 
keys(&self) -> &[u8] { + unsafe { self.keys.get_unchecked(..self.len as usize) } + } + + fn mut_keys(&mut self) -> &mut [u8] { + unsafe { self.keys.get_unchecked_mut(..self.len as usize) } + } + + fn children(&self) -> &[Node] { + unsafe { self.children.get_unchecked(..self.len as usize) } + } + + fn mut_children(&mut self) -> &mut [Node] { + unsafe { self.children.get_unchecked_mut(..self.len as usize) } + } +} + +impl NodeOps for Node16 { + fn header(&self) -> &NodeHeader { + &self.header + } + + fn header_mut(&mut self) -> &mut NodeHeader { + &mut self.header + } + + fn size(&self) -> usize { + self.len + } + + fn is_full(&self) -> bool { + self.len == 16 + } + + fn is_shrinkable(&self) -> bool { + self.len <= 4 + } + + fn get_any_child(&self) -> Option<&NodeV> { + debug_assert!(self.size() > 0); + + match unsafe { self.children.get_unchecked(0).deref() } { + Either::Left(node) => node.get_any_child(), + Either::Right(nodev) => Some(nodev), + } + } + + fn insert(&mut self, key: u8, node: Node) -> Result<(), Node> { + debug_assert!(!self.is_full()); + + for (index, k) in self.keys().iter().enumerate() { + match key.cmp(k) { + Ordering::Less => unsafe { + self.len += 1; + slice_insert(self.mut_keys(), index, key); + slice_insert(self.mut_children(), index, node); + return Ok(()); + }, + Ordering::Equal => return Err(node), + Ordering::Greater => {} + } + } + + let index = self.len; + unsafe { + self.len += 1; + slice_insert(self.mut_keys(), index, key); + slice_insert(self.mut_children(), index, node); + } + + Ok(()) + } + + fn lookup(&self, key: u8) -> Option<&Node> { + for (index, k) in self.keys().iter().enumerate() { + if key == *k { + return unsafe { Some(self.children.get_unchecked(index)) }; + } + } + + None + } + + fn lookup_mut(&mut self, key: u8) -> Option<&mut Node> { + for (index, k) in self.keys().iter().enumerate() { + if key == *k { + return unsafe { Some(self.children.get_unchecked_mut(index)) }; + } + } + + None + } + + fn update(&mut self, key: 
u8, node: Node) -> Result, Node> { + for (index, k) in self.keys().iter().enumerate() { + match key.cmp(k) { + Ordering::Less => {} + Ordering::Equal => unsafe { + let node = mem::replace(self.children.get_unchecked_mut(index), node); + return Ok(node); + }, + Ordering::Greater => {} + } + } + + Err(node) + } + + fn remove(&mut self, key: u8) -> Result, ()> { + debug_assert!(self.len != 0); + + for (index, k) in self.keys().iter().enumerate() { + match key.cmp(k) { + Ordering::Less => {} + Ordering::Equal => unsafe { + let _ = slice_remove(self.mut_keys(), index); + let node = slice_remove(self.mut_children(), index); + self.len -= 1; + return Ok(node); + }, + Ordering::Greater => {} + } + } + + Err(()) + } +} +struct Node48 { + header: NodeHeader, + len: usize, + keys: [u8; 256], + children: [Node; 48], +} + +impl Debug for Node48 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let valid_keys = self + .keys + .iter() + .enumerate() + .filter(|(_, index)| **index != 0xff) + .map(|(key, _)| key) + .collect::>(); + + let valid_children = valid_keys + .iter() + .map(|key| &self.children[self.keys[*key] as usize]) + .collect::>(); + + f.debug_struct("Node48") + .field("header", &self.header) + .field("len", &self.len) + .field("keys", &valid_keys) + .field("children", &valid_children) + .finish() + } +} + +impl Default for Node48 { + #[allow(deprecated)] + fn default() -> Self { + Self { + header: Default::default(), + len: 0, + keys: arr![0xff; 256], // the invalid index is 0xff + children: arr![Node::null(); 48], + } + } +} + +impl From> for Node48 { + fn from(node: Node16) -> Self { + debug_assert!(node.len == 16); + + let mut new = Self::default(); + + unsafe { + for (index, key) in node.keys().iter().enumerate() { + *new.keys.get_unchecked_mut(*key as usize) = index as u8; + } + + ptr::copy_nonoverlapping( + node.children.as_ptr(), + new.children.as_mut_ptr(), + node.len as usize, + ); + } + + new.header = node.header; + new.len = 
node.len; + + new + } +} + +impl From> for Node48 { + fn from(node: Node256) -> Self { + debug_assert!(node.len <= 48); + + let mut new = Self::default(); + + unsafe { + for (key, child) in node.children.iter().enumerate() { + if !child.is_null() { + *new.keys.get_unchecked_mut(key) = new.len as u8; + *new.children.get_unchecked_mut(new.len) = ptr::read(child); + new.len += 1; + } + } + } + + new.header = node.header; + + new + } +} + +impl NodeOps for Node48 { + fn header(&self) -> &NodeHeader { + &self.header + } + + fn header_mut(&mut self) -> &mut NodeHeader { + &mut self.header + } + + fn size(&self) -> usize { + self.len + } + + fn is_full(&self) -> bool { + self.len == 48 + } + + fn is_shrinkable(&self) -> bool { + self.len <= 16 + } + + fn get_any_child(&self) -> Option<&NodeV> { + debug_assert!(self.size() > 0); + + match unsafe { self.children.get_unchecked(0).deref() } { + Either::Left(node) => node.get_any_child(), + Either::Right(nodev) => Some(nodev), + } + } + + fn insert(&mut self, key: u8, node: Node) -> Result<(), Node> { + debug_assert!(!self.is_full()); + + let index = unsafe { self.keys.get_unchecked_mut(key as usize) }; + + if *index != 0xff { + Err(node) + } else { + for (idx, child) in self.children.iter_mut().enumerate() { + if child.is_null() { + *child = node; + *index = idx as u8; + self.len += 1; + return Ok(()); + } + } + + unreachable!() + } + } + + fn lookup(&self, key: u8) -> Option<&Node> { + let index = unsafe { self.keys.get_unchecked(key as usize) }; + + if *index == 0xff { + None + } else { + unsafe { Some(self.children.get_unchecked(*index as usize)) } + } + } + + fn lookup_mut(&mut self, key: u8) -> Option<&mut Node> { + let index = unsafe { self.keys.get_unchecked(key as usize) }; + + if *index == 0xff { + None + } else { + unsafe { Some(self.children.get_unchecked_mut(*index as usize)) } + } + } + + fn update(&mut self, key: u8, node: Node) -> Result, Node> { + let index = unsafe { self.keys.get_unchecked_mut(key as usize) 
}; + + if *index == 0xff { + Err(node) + } else { + let child = unsafe { self.children.get_unchecked_mut(*index as usize) }; + let old = mem::replace(child, node); + Ok(old) + } + } + + fn remove(&mut self, key: u8) -> Result, ()> { + let index = unsafe { self.keys.get_unchecked(key as usize).clone() }; + + if index == 0xff { + Err(()) + } else { + unsafe { + let node = mem::replace( + self.children.get_unchecked_mut(index as usize), + Node::null(), + ); + *self.keys.get_unchecked_mut(key as usize) = 0xff; + self.len -= 1; + Ok(node) + } + } + } +} + +struct Node256 { + header: NodeHeader, + len: usize, + children: [Node; 256], +} + +impl Debug for Node256 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let valid_children = self + .children + .iter() + .enumerate() + .filter(|(_, child)| !child.is_null()) + .collect::>(); + + f.debug_struct("Node256") + .field("header", &self.header) + .field("len", &self.len) + .field("children", &valid_children) + .finish() + } +} + +impl Default for Node256 { + #[allow(deprecated)] + fn default() -> Self { + Self { + header: Default::default(), + len: 0, + children: arr![Node::null(); 256], + } + } +} + +impl From> for Node256 { + fn from(node: Node48) -> Self { + debug_assert!(node.len == 48); + + let mut new = Self::default(); + + unsafe { + for (key, index) in node.keys.iter().enumerate() { + if *index != 0xff { + *new.children.get_unchecked_mut(key) = + ptr::read(node.children.get_unchecked(*index as usize)); + } + } + } + + new.header = node.header; + new.len = node.len; + + new + } +} + +impl NodeOps for Node256 { + fn header(&self) -> &NodeHeader { + &self.header + } + + fn header_mut(&mut self) -> &mut NodeHeader { + &mut self.header + } + + fn size(&self) -> usize { + self.len + } + + fn is_full(&self) -> bool { + self.len == 256 + } + + fn is_shrinkable(&self) -> bool { + self.len <= 48 + } + + fn get_any_child(&self) -> Option<&NodeV> { + debug_assert!(self.size() > 0); + + for child in 
self.children.iter() { + if !child.is_null() { + return match child.deref() { + Either::Left(node) => node.get_any_child(), + Either::Right(nodev) => Some(nodev), + }; + } + } + + unreachable!() + } + + fn insert(&mut self, key: u8, node: Node) -> Result<(), Node> { + let child = unsafe { self.children.get_unchecked_mut(key as usize) }; + + if child.is_null() { + self.len += 1; + *child = node; + Ok(()) + } else { + Err(node) + } + } + + fn lookup(&self, key: u8) -> Option<&Node> { + let child = unsafe { self.children.get_unchecked(key as usize) }; + + if child.is_null() { + None + } else { + Some(child) + } + } + + fn lookup_mut(&mut self, key: u8) -> Option<&mut Node> { + let child = unsafe { self.children.get_unchecked_mut(key as usize) }; + + if child.is_null() { + None + } else { + Some(child) + } + } + + fn update(&mut self, key: u8, node: Node) -> Result, Node> { + let child = unsafe { self.children.get_unchecked_mut(key as usize) }; + + if child.is_null() { + Err(node) + } else { + let old = mem::replace(child, node); + Ok(old) + } + } + + fn remove(&mut self, key: u8) -> Result, ()> { + let child = unsafe { self.children.get_unchecked_mut(key as usize) }; + + if child.is_null() { + Err(()) + } else { + let node = mem::replace(child, Node::null()); + self.len -= 1; + Ok(node) + } + } +} + +pub trait Encodable { + fn encode(&self) -> Vec; +} + +impl Encodable for String { + fn encode(&self) -> Vec { + let mut array = self.clone().into_bytes(); + array.push(KEY_ENDMARK); // prevent to certain string cannot be the prefix of another string + array + } +} + +pub struct ART { + root: Node, + _marker: PhantomData, +} + +impl Debug for ART { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ART").field("root", &self.root).finish() + } +} + +impl Drop for ART { + fn drop(&mut self) { + fn clean(node: &Node) { + match node.node_type() { + NodeType::Value => unsafe { drop(ptr::read(node).inner::>()) }, + NodeType::Node4 => { + 
let node4 = unsafe { ptr::read(node).inner::>() }; + + for child in node4.children() { + clean(child); + } + } + NodeType::Node16 => { + let node16 = unsafe { ptr::read(node).inner::>() }; + + for child in node16.children() { + clean(child); + } + } + NodeType::Node48 => { + let node48 = unsafe { ptr::read(node).inner::>() }; + + for child in &node48.children { + if !child.is_null() { + clean(child); + } + } + } + NodeType::Node256 => { + let node256 = unsafe { ptr::read(node).inner::>() }; + + for child in &node256.children { + if !child.is_null() { + clean(child); + } + } + } + } + } + + clean(&self.root); + } +} + +impl ART { + pub fn print_debug_info(&self) { + println!("V is {:>5} byte.", mem::size_of::()); + println!("NodeV is {:>5} byte.", mem::size_of::>()); + println!("NodeHeader is {:>5} byte.", mem::size_of::()); + println!("Node is {:>5} byte.", mem::size_of::>()); + println!("Node4 is {:>5} byte.", mem::size_of::>()); + println!("Node16 is {:>5} byte.", mem::size_of::>()); + println!("Node48 is {:>5} byte.", mem::size_of::>()); + println!("Node256 is {:>5} byte.", mem::size_of::>()); + } +} + +impl SequentialMap for ART { + fn new() -> Self { + let root = Node::new(Node256::::default(), NodeType::Node256); + + Self { + root, + _marker: PhantomData, + } + } + + fn insert(&mut self, key: &K, value: V) -> Result<(), V> { + let keys = key.encode(); + let mut depth = 0; + let mut common_prefix: u32 = 0; + let mut current = NonNull::new(&mut self.root).unwrap(); + + while depth < keys.len() { + let current_ref = unsafe { current.as_mut() }; + let node = left_or!(current_ref.deref_mut(), break); + + if let Err(common_depth) = Node::prefix_match(&keys, node, depth) { + common_prefix = (common_depth - depth) as u32; + break; + } + + let prefix = node.header().len; + + if let Some(node) = node.lookup_mut(keys[depth + prefix as usize]) { + depth += 1 + prefix as usize; + current = NonNull::new(node).unwrap(); + } else { + common_prefix = prefix; + break; + } + } 
+ + let current_ref = unsafe { current.as_mut() }; + current_ref.extend(); + + match current_ref.deref_mut() { + Either::Left(node) => { + let new = NodeV::new(keys.clone(), value); + + if common_prefix == node.header().len { + // just insert value into this node + // println!("just insert"); + let key = keys[depth + common_prefix as usize]; + let insert = node.insert(key, Node::new(new, NodeType::Value)); + debug_assert!(insert.is_ok()); + } else { + drop(node); // since the current(ref of node) will be changed, drop it for safety not to use it. + + // split prefix + let key = keys[depth + common_prefix as usize]; + let mut inter_node = Node4::::default(); + + unsafe { + ptr::copy_nonoverlapping( + keys.as_ptr().add(depth), + inter_node.header.prefix.as_mut_ptr(), + common_prefix as usize, + ); + } + inter_node.header.len = common_prefix; + + // replace with inter_node and get old node + let current = unsafe { current.as_mut() }; + let old = mem::replace(current, Node::new(inter_node, NodeType::Node4)); + let current = current.deref_mut().left().unwrap(); + + // get old's key and re-set the old's prefix + let old_ref = old.deref_mut().left().unwrap(); + let header = old_ref.header(); + + let old_key; + + if header.len > PREFIX_LEN as u32 { + // need to get omitted prefix from any child + let full_key = old_ref.get_any_child().unwrap().key.clone(); + let prefix_start = depth + common_prefix as usize + 1; + + let header = old_ref.header_mut(); + unsafe { + ptr::copy_nonoverlapping( + full_key.as_ptr().add(prefix_start), + header.prefix.as_mut_ptr(), + min( + PREFIX_LEN, + header.len as usize - (common_prefix + 1) as usize, + ), + ) + }; + header.len -= common_prefix + 1; + + old_key = + unsafe { *full_key.get_unchecked(depth + common_prefix as usize) }; + } else { + // just move prefix + old_key = unsafe { *header.prefix.get_unchecked(common_prefix as usize) }; + + let header = old_ref.header_mut(); + unsafe { + ptr::copy( + header.prefix.as_ptr().add(common_prefix 
as usize + 1), + header.prefix.as_mut_ptr(), + (header.len - (common_prefix + 1)) as usize, + ) + }; + header.len -= common_prefix + 1; + } + + let insert_old = current.insert(old_key, old); + debug_assert!(insert_old.is_ok()); + let insert_new = current.insert(key, Node::new(new, NodeType::Value)); + debug_assert!(insert_new.is_ok()); + } + + Ok(()) + } + Either::Right(nodev) => { + if *nodev.key == keys { + return Err(value); + } + + let new = NodeV::new(keys.clone(), value); + + // insert inter node with zero prefix + // ex) 'aE', 'aaE' + let mut common_prefix = 0; + + while keys[depth + common_prefix] == nodev.key[depth + common_prefix] { + common_prefix += 1; + } + + drop(nodev); // since the nodev will be changed, drop it for safety not to use it. + + let mut inter_node = Node4::::default(); + unsafe { + ptr::copy_nonoverlapping( + keys.as_ptr().add(depth), + inter_node.header.prefix.as_mut_ptr(), + min(PREFIX_LEN, common_prefix), + ); + } + inter_node.header.len = common_prefix as u32; + + let current = unsafe { current.as_mut() }; + let old = mem::replace(current, Node::new(inter_node, NodeType::Node4)); + let current = current.deref_mut().left().unwrap(); + + let old_full_key = &old.deref().right().unwrap().key; + let insert_old = current.insert(old_full_key[depth + common_prefix], old); + debug_assert!(insert_old.is_ok()); + let insert_new = + current.insert(keys[depth + common_prefix], Node::new(new, NodeType::Value)); + debug_assert!(insert_new.is_ok()); + + Ok(()) + } + } + } + + fn lookup(&self, key: &K) -> Option<&V> { + let keys = key.encode(); + let mut depth = 0; + + let mut current = &self.root; + + while depth < keys.len() { + let node = left_or!(current.deref(), break); + depth += node.header().len as usize; + + if depth >= keys.len() { + return None; + } + + if let Some(node) = node.lookup(keys[depth]) { + depth += 1; + current = node; + } else { + return None; + } + } + + match current.deref() { + Either::Left(_) => None, + 
Either::Right(nodev) => {
+                if *nodev.key == keys {
+                    Some(&nodev.value)
+                } else {
+                    None
+                }
+            }
+        }
+    }
+
+    /// Removes `key` and returns its value, or `Err(())` when the key is not stored.
+    ///
+    /// The descent skips each node's compressed prefix without comparing it;
+    /// the complete key is only compared once a value node is reached.
+    fn remove(&mut self, key: &K) -> Result {
+        let keys = key.encode();
+        let mut depth = 0;
+
+        // `parent` is the node visited just before `current`, if any.
+        let mut parent = None;
+        let mut current = NonNull::new(&mut self.root).unwrap();
+
+        while depth < keys.len() {
+            let current_ref = unsafe { current.as_mut() };
+            let node = current_ref.deref_mut().unwrap_left();
+            // skip the compressed prefix of this node
+            depth += node.header().len as usize;
+
+            if depth >= keys.len() {
+                return Err(());
+            }
+
+            if let Some(node) = node.lookup_mut(keys[depth]) {
+                if node.node_type() == NodeType::Value {
+                    // reached a value: it is the target only if the full keys are equal
+                    if *node.deref().right().unwrap().key == keys {
+                        break;
+                    } else {
+                        return Err(());
+                    }
+                }
+
+                // descend into the inner child
+                depth += 1;
+                parent = Some(current);
+                current = NonNull::new(node).unwrap();
+            } else {
+                return Err(());
+            }
+        }
+
+        // NOTE(review): if the loop ends through its condition instead of `break`
+        // (the last key byte led to an inner node, or `encode` returned an empty
+        // key), `keys[depth]` below is out of bounds. Presumably the end mark
+        // appended by `encode` rules this out — confirm.
+        let current_ref = unsafe { current.as_mut() };
+        let current_node = current_ref.deref_mut().left().unwrap();
+        let node = current_node.remove(keys[depth]);
+        debug_assert!(node.is_ok());
+        let node = node.unwrap().inner::>();
+
+        // if the path can be compressed because only one child is left, do it.
+
+        let current_ref = unsafe { current.as_mut() };
+        current_ref.compress_path();
+
+        // if the node is still an inner node after the path-compression attempt, then
+        if let Either::Left(current_node) = current_ref.deref_mut() {
+            if let Some(mut parent) = parent {
+                if current_node.size() == 0 {
+                    // the node has no children left: detach it from its parent and drop it
+                    let parent = unsafe { parent.as_mut() };
+                    let parent_ref = parent.deref_mut().left().unwrap();
+
+                    // the byte that selected this node sits just before its prefix in the key
+                    let remove =
+                        parent_ref.remove(keys[depth - current_node.header().len as usize - 1]);
+                    debug_assert!(remove.is_ok());
+                    let remove = remove.unwrap();
+                    debug_assert_eq!(remove.deref().left().unwrap().size(), 0);
+                    debug_assert_eq!(remove.node_type(), NodeType::Node4);
+                    remove.inner::>();
+                } else if current_node.is_shrinkable() {
+                    // shrink the node
+                    current_ref.shrink();
+                }
+            }
+        }
+
+        Ok(node.value)
+    }
+}
diff --git a/src/art/utf8art/mod.rs b/src/art/utf8art/mod.rs
new file mode 100644
index 0000000..73397a1
--- /dev/null
+++ b/src/art/utf8art/mod.rs
@@ -0,0 +1,48 @@
+use std::{cmp::min, fmt::Debug, mem::MaybeUninit};
+
+use uninit::uninit_array;
+
+const PREFIX_LEN: usize = 12;
+const KEY_ENDMARK: u8 = 0xff; // 0xff never occurs in valid UTF-8, so it is used as an end mark that guarantees no key is a prefix of another key.
+struct NodeHeader {
+    len: u32,                             // length of the prefix; may exceed PREFIX_LEN, in which case only the first PREFIX_LEN bytes are stored
+    prefix: [MaybeUninit; PREFIX_LEN], // stored prefix bytes for path compression; presumably only the first min(len, PREFIX_LEN) are initialized
+}
+
+impl Debug for NodeHeader {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        unsafe {
+            f.debug_struct("NodeHeader")
+                .field("len", &self.len)
+                .field(
+                    "prefix",
+                    &self
+                        .prefix
+                        .get_unchecked(..min(PREFIX_LEN, self.len as usize)),
+                )
+                .finish()
+        }
+    }
+}
+
+impl Default for NodeHeader {
+    fn default() -> Self {
+        Self {
+            len: 0,
+            prefix: uninit_array![u8; PREFIX_LEN],
+        }
+    }
+}
+
+/// Mask for the type tag of a child node (see `NodeType`).
+/// The tag is used as a bit flag on the child pointer.
+const NODETYPE_MASK: usize = 0b111; +#[repr(usize)] +#[derive(Debug, PartialEq)] +enum NodeType { + Value = 0b000, + Node4 = 0b001, + Node16 = 0b010, + Node48 = 0b011, + Node256 = 0b100, +} diff --git a/src/btree/mod.rs b/src/btree/mod.rs index b7c03e9..091f6e9 100644 --- a/src/btree/mod.rs +++ b/src/btree/mod.rs @@ -4,41 +4,11 @@ use std::ptr; use std::{cmp::Ordering, mem, ptr::NonNull}; use crate::map::SequentialMap; +use crate::util::{slice_insert, slice_remove}; const B_MAX_NODES: usize = 11; const B_MID_INDEX: usize = B_MAX_NODES / 2; -/// insert value into [T], which has one empty area on last. -/// ex) insert C at 1 into [A, B, uninit] => [A, C, B] -unsafe fn slice_insert(ptr: &mut [T], index: usize, value: T) { - let size = ptr.len(); - debug_assert!(size > index); - - let ptr = ptr.as_mut_ptr(); - - if size > index + 1 { - ptr::copy(ptr.add(index), ptr.add(index + 1), size - index - 1); - } - - ptr::write(ptr.add(index), value); -} - -/// remove value from [T] and remain last area without any init -/// ex) remove at 1 from [A, B, C] => [A, C, C(but you should not access here)] -unsafe fn slice_remove(ptr: &mut [T], index: usize) -> T { - let size = ptr.len(); - debug_assert!(size > index); - - let ptr = ptr.as_mut_ptr(); - let value = ptr::read(ptr.add(index)); - - if size > index + 1 { - ptr::copy(ptr.add(index + 1), ptr.add(index), size - index - 1); - } - - value -} - struct Node { size: usize, depth: usize, diff --git a/src/lib.rs b/src/lib.rs index d8a6626..d8b453d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod avltree; pub mod btree; +pub mod art; pub mod linkedlist; pub mod lock; pub mod map; diff --git a/src/linkedlist/mod.rs b/src/linkedlist/mod.rs index 7fd37e0..f7c8f88 100644 --- a/src/linkedlist/mod.rs +++ b/src/linkedlist/mod.rs @@ -1,10 +1,12 @@ use crate::map::SequentialMap; // simple sequential linked list +#[derive(Debug)] pub struct LinkedList { head: Node, // dummy node with key = Default, but the key is not considered 
on algorithm } +#[derive(Debug)] struct Node { key: K, value: V, diff --git a/src/util/mod.rs b/src/util/mod.rs index cf6de98..90a1913 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1,5 +1,38 @@ +use std::ptr; + pub mod random; +/// insert value into [T], which has one empty area on last. +/// ex) insert C at 1 into [A, B, uninit] => [A, C, B] +pub unsafe fn slice_insert(ptr: &mut [T], index: usize, value: T) { + let size = ptr.len(); + debug_assert!(size > index); + + let ptr = ptr.as_mut_ptr(); + + if size > index + 1 { + ptr::copy(ptr.add(index), ptr.add(index + 1), size - index - 1); + } + + ptr::write(ptr.add(index), value); +} + +/// remove value from [T] and remain last area without any init +/// ex) remove at 1 from [A, B, C] => [A, C, C(but you should not access here)] +pub unsafe fn slice_remove(ptr: &mut [T], index: usize) -> T { + let size = ptr.len(); + debug_assert!(size > index); + + let ptr = ptr.as_mut_ptr(); + let value = ptr::read(ptr.add(index)); + + if size > index + 1 { + ptr::copy(ptr.add(index + 1), ptr.add(index), size - index - 1); + } + + value +} + #[macro_export] macro_rules! ok_or { ($e:expr, $err:expr) => {{ @@ -19,3 +52,13 @@ macro_rules! some_or { } }}; } + +#[macro_export] +macro_rules! 
left_or { + ($e:expr, $err:expr) => {{ + match $e { + Either::Left(l) => l, + Either::Right(_) => $err, + } + }}; +} diff --git a/src/util/random.rs b/src/util/random.rs index cb6b358..2856518 100644 --- a/src/util/random.rs +++ b/src/util/random.rs @@ -6,7 +6,7 @@ pub trait Random { } const RANDOM_STRING_MIN: usize = 0; -const RANDOM_STRING_MAX: usize = 10; +const RANDOM_STRING_MAX: usize = 128; impl Random for String { // get random string whose length is in [RANDOM_STRING_MIN, RANDOM_STRING_MAX) diff --git a/tests/art/mod.rs b/tests/art/mod.rs new file mode 100644 index 0000000..0c44756 --- /dev/null +++ b/tests/art/mod.rs @@ -0,0 +1,125 @@ +use cds::art::ART; +use cds::map::SequentialMap; +use rand::prelude::SliceRandom; +use rand::thread_rng; + +use crate::util::map::stress_sequential; + +#[test] +fn test_art() { + let mut art: ART = ART::new(); + + assert_eq!(art.insert(&"a".to_string(), 1), Ok(())); + assert_eq!(art.insert(&"ab".to_string(), 2), Ok(())); + assert_eq!(art.insert(&"ac".to_string(), 3), Ok(())); + assert_eq!(art.insert(&"ad".to_string(), 4), Ok(())); + assert_eq!(art.insert(&"acb".to_string(), 5), Ok(())); + + assert_eq!(art.lookup(&"a".to_string()), Some(&1)); + assert_eq!(art.lookup(&"ab".to_string()), Some(&2)); + assert_eq!(art.lookup(&"ac".to_string()), Some(&3)); + assert_eq!(art.lookup(&"ad".to_string()), Some(&4)); + assert_eq!(art.lookup(&"acb".to_string()), Some(&5)); + + assert_eq!(art.remove(&"a".to_string()), Ok(1)); + assert_eq!(art.remove(&"ab".to_string()), Ok(2)); + assert_eq!(art.remove(&"ac".to_string()), Ok(3)); + assert_eq!(art.remove(&"ad".to_string()), Ok(4)); + assert_eq!(art.remove(&"acb".to_string()), Ok(5)); +} + +#[test] +#[rustfmt::skip] +fn test_large_key_art() { + let mut art: ART = ART::new(); + assert_eq!(art.insert(&"1234567890".to_string(), 1), Ok(())); + assert_eq!(art.insert(&"12345678901234567890".to_string(), 2), Ok(())); + assert_eq!(art.insert(&"123456789012345678901234567890".to_string(), 3), Ok(())); + 
assert_eq!(art.insert(&"1234567890123456789012345678901234567890".to_string(), 4), Ok(())); + assert_eq!(art.insert(&"12345678901234567890123456789012345678901234567890".to_string(), 5), Ok(())); + assert_eq!(art.insert(&"123456789012345678901234567890123456789012345678901234567890".to_string(), 6), Ok(())); + assert_eq!(art.lookup(&"1234567890".to_string()), Some(&1)); + assert_eq!(art.lookup(&"12345678901234567890".to_string()), Some(&2)); + assert_eq!(art.lookup(&"123456789012345678901234567890".to_string()), Some(&3)); + assert_eq!(art.lookup(&"1234567890123456789012345678901234567890".to_string()), Some(&4)); + assert_eq!(art.lookup(&"12345678901234567890123456789012345678901234567890".to_string()), Some(&5)); + assert_eq!(art.lookup(&"123456789012345678901234567890123456789012345678901234567890".to_string()), Some(&6)); + assert_eq!(art.remove(&"1234567890".to_string()), Ok(1)); + assert_eq!(art.remove(&"12345678901234567890".to_string()), Ok(2)); + assert_eq!(art.remove(&"123456789012345678901234567890".to_string()), Ok(3)); + assert_eq!(art.remove(&"1234567890123456789012345678901234567890".to_string()), Ok(4)); + assert_eq!(art.remove(&"12345678901234567890123456789012345678901234567890".to_string()), Ok(5)); + assert_eq!(art.remove(&"123456789012345678901234567890123456789012345678901234567890".to_string()), Ok(6)); +} + +#[test] +#[rustfmt::skip] +fn test_split_key_art() { + let mut art: ART = ART::new(); + assert_eq!(art.insert(&"123456789012345678901234567890123456789012345678901234567890".to_string(), 6), Ok(())); + assert_eq!(art.insert(&"12345678901234567890123456789012345678901234567890".to_string(), 5), Ok(())); + assert_eq!(art.lookup(&"12345678901234567890123456789012345678901234567890".to_string()), Some(&5)); + assert_eq!(art.lookup(&"123456789012345678901234567890123456789012345678901234567890".to_string()), Some(&6)); + assert_eq!(art.insert(&"1234567890123456789012345678901234567890".to_string(), 4), Ok(())); + 
assert_eq!(art.insert(&"123456789012345678901234567890".to_string(), 3), Ok(())); + assert_eq!(art.insert(&"12345678901234567890".to_string(), 2), Ok(())); + assert_eq!(art.insert(&"1234567890".to_string(), 1), Ok(())); + assert_eq!(art.lookup(&"1234567890".to_string()), Some(&1)); + assert_eq!(art.lookup(&"12345678901234567890".to_string()), Some(&2)); + assert_eq!(art.lookup(&"123456789012345678901234567890".to_string()), Some(&3)); + assert_eq!(art.lookup(&"1234567890123456789012345678901234567890".to_string()), Some(&4)); + assert_eq!(art.lookup(&"12345678901234567890123456789012345678901234567890".to_string()), Some(&5)); + assert_eq!(art.lookup(&"123456789012345678901234567890123456789012345678901234567890".to_string()), Some(&6)); + assert_eq!(art.remove(&"123456789012345678901234567890123456789012345678901234567890".to_string()), Ok(6)); + assert_eq!(art.remove(&"12345678901234567890123456789012345678901234567890".to_string()), Ok(5)); + assert_eq!(art.remove(&"1234567890123456789012345678901234567890".to_string()), Ok(4)); + assert_eq!(art.remove(&"123456789012345678901234567890".to_string()), Ok(3)); + assert_eq!(art.remove(&"12345678901234567890".to_string()), Ok(2)); + assert_eq!(art.remove(&"1234567890".to_string()), Ok(1)); +} + +#[test] +fn test_extend_shrink_art() { + let mut art: ART = ART::new(); + let mut keys = Vec::new(); + + for i in '0'..'z' { + let key = "a".to_string() + &i.to_string(); + assert_eq!(art.insert(&key, i as usize), Ok(())); + keys.push(key); + + for k in &keys { + assert!(art.lookup(k).is_some(), "key: {:?}", k); + } + } + + let mut rng = thread_rng(); + keys.shuffle(&mut rng); + + let mut removed_keys = Vec::new(); + + for _ in 0..keys.len() { + let key = keys.pop().unwrap(); + assert!(art.remove(&key).is_ok()); + removed_keys.push(key); + + for k in &keys { + assert!(art.lookup(k).is_some(), "key: {:?}", k); + } + + for k in &removed_keys { + assert!(art.lookup(k).is_none(), "key: {:?}", k); + } + } +} + +#[test] +fn 
stress_art() { + stress_sequential::>(1_000_000); +} + +#[test] +fn debug_art() { + let art: ART = ART::new(); + art.print_debug_info(); +} diff --git a/tests/tests.rs b/tests/tests.rs index e09d14f..8659391 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,5 +1,6 @@ mod avltree; mod btree; +mod art; mod linkedlist; mod lock; mod queue; diff --git a/tests/util/map.rs b/tests/util/map.rs index 63e4b81..c1370a0 100644 --- a/tests/util/map.rs +++ b/tests/util/map.rs @@ -32,7 +32,7 @@ enum OperationType { pub fn stress_sequential(iter: u64) where K: Ord + Clone + Random + Debug, - M: SequentialMap, + M: SequentialMap + Debug, { // 10 times try to get not existing key, or return if failing let gen_not_existing_key = |rng: &mut ThreadRng, map: &BTreeMap| { @@ -129,17 +129,22 @@ where // value.unwrap() // ); assert_eq!(map.remove(&existing_key).ok(), value); - - // early stop code if the remove has any problems - // for key in ref_map.keys().collect::>() { - // assert_eq!(map.lookup(key).is_some(), true, "the key {:?} is not found.", key); - // } } } } + + // early stop code if the op has any problems + // for key in ref_map.keys().collect::>() { + // if map.lookup(key).is_none() { + // println!("before: {}", before); + // println!("after: {:?}", map); + // panic!("the key {:?} is not found.", key); + // } + // } } } +#[derive(Debug)] struct Sequentialized where K: Eq, @@ -188,7 +193,7 @@ where pub fn stress_concurrent_as_sequential(iter: u64) where K: Ord + Clone + Random + Debug, - M: ConcurrentMap, + M: ConcurrentMap + Debug, { stress_sequential::>(iter) }