diff --git a/Cargo.lock b/Cargo.lock index c2d013c..bdfda39 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -458,6 +458,21 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -465,6 +480,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -473,6 +489,34 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -491,10 +535,16 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", + "futures-io", + "futures-macro", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -1099,6 +1149,7 @@ version = "0.1.0" dependencies = [ "cpu-time", "diskann", + "futures", "service", "system", "tokio", diff --git a/Cargo.toml b/Cargo.toml index a61a056..fed146d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ tokio = { version = "1.47.1", features = ["full"] } tonic = { version = "0.14.2", features = ["transport"] } tonic-prost = "0.14.2" prost = "0.14" +futures = "0.3" serde = { version = "1.0.210", features = ["derive", "rc"] } serde_json = "1.0" uuid = { version = "1.3", features = ["v4"] } diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 6c5263d..4eca2a1 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -16,6 +16,7 @@ service = { path = "../nyas/service" } tokio.workspace = true tonic.workspace = true cpu-time = "1.0.0" +futures.workspace = true [[bin]] name = "sift10k_index" diff --git a/examples/src/index_utils.rs b/examples/src/index_utils.rs index 2051bc6..a6f4c84 100644 --- a/examples/src/index_utils.rs +++ b/examples/src/index_utils.rs @@ -3,8 +3,11 @@ use std::path::Path; use diskann::index_view::IndexView; use system::vector_data::VectorData; -use tokio::fs::File; -use tokio::io::{self, AsyncReadExt, BufReader}; +use tokio::io; + +fn env_usize(name: &str) -> Option { + std::env::var(name).ok()?.parse::().ok().filter(|v| *v > 0) +} #[derive(Debug)] pub struct SiftDataset { @@ -14,48 +17,39 @@ pub struct SiftDataset { impl SiftDataset { pub async fn from_fvecs>(path: P) -> io::Result { - let file = File::open(path).await?; - let mut reader = BufReader::new(file); - - let mut vectors = Vec::new(); - let mut dimension = 0u32; - let mut buffer = [0u8; 4]; + let bytes = tokio::fs::read(path).await?; + if bytes.len() < 4 { + return Ok(SiftDataset { dimension: 0, vectors: Vec::new() }); + } - if reader.read_exact(&mut buffer).await.is_ok() { - dimension = u32::from_le_bytes(buffer); + let dimension = u32::from_le_bytes(bytes[0..4].try_into().unwrap()); + let record_size = 4usize + 4usize * dimension as usize; + if record_size == 0 || bytes.len() % record_size != 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Invalid fvecs layout or truncated file", + )); + } - // Read first vector - let mut vec = vec![0f32; dimension as usize]; - for item in vec.iter_mut().take(dimension as usize) { - reader.read_exact(&mut buffer).await?; - *item = f32::from_le_bytes(buffer); + let mut vectors = Vec::with_capacity(bytes.len() / record_size); + let mut offset = 0usize; + while offset + record_size <= bytes.len() { + let dim = u32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap()); + if dim != dimension { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Inconsistent vector dimensions", + )); } - let vec_data = VectorData::from_f32(vec); - vectors.push(vec_data); - } - loop { - match reader.read_exact(&mut buffer).await { - Ok(_) => { - let dim = u32::from_le_bytes(buffer); - if dim != dimension { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Inconsistent vector dimensions", - )); - } - - let mut vec = vec![0f32; dimension as usize]; - for item in vec.iter_mut().take(dimension as usize) { - reader.read_exact(&mut buffer).await?; - *item = f32::from_le_bytes(buffer); - } - let vec_data = VectorData::from_f32(vec); - vectors.push(vec_data); - } - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break, - Err(e) => return Err(e), + let mut vec = vec![0f32; dimension as usize]; + let data_start = offset + 4; + for (i, item) in vec.iter_mut().enumerate() { + let s = data_start + i * 4; + *item = f32::from_le_bytes(bytes[s..s + 4].try_into().unwrap()); } + vectors.push(VectorData::from_f32(vec)); + offset += record_size; } Ok(SiftDataset { dimension, vectors }) @@ -63,70 +57,68 @@ impl SiftDataset { #[allow(dead_code)] pub async fn from_bvecs>(path: P) -> io::Result { - let file = File::open(path).await?; - let mut reader = BufReader::new(file); - - let mut vectors = Vec::new(); - let mut dimension = 0u32; - let mut buffer = [0u8; 4]; - - if reader.read_exact(&mut buffer).await.is_ok() { - dimension = u32::from_le_bytes(buffer); + let bytes = tokio::fs::read(path).await?; + if bytes.len() < 4 { + return Ok(SiftDataset { dimension: 0, vectors: Vec::new() }); + } - let mut vec = vec![0u8; dimension as usize]; - reader.read_exact(&mut vec).await?; - let vec = vec.iter().map(|&x| x as f32).collect::>(); - let vec_data = VectorData::from_f32(vec); - vectors.push(vec_data); + let dimension = u32::from_le_bytes(bytes[0..4].try_into().unwrap()); + let record_size = 4usize + dimension as usize; + if record_size == 0 || bytes.len() % record_size != 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Invalid bvecs layout or truncated file", + )); } - loop { - match reader.read_exact(&mut buffer).await { - Ok(_) => { - let dim = u32::from_le_bytes(buffer); - if dim != dimension { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Inconsistent vector dimensions", - )); - } - - let mut vec = vec![0u8; dimension as usize]; - reader.read_exact(&mut vec).await?; - let vec = vec.iter().map(|&x| x as f32).collect::>(); - let vec_data = VectorData::from_f32(vec); - vectors.push(vec_data); - } - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break, - Err(e) => return Err(e), + let mut vectors = Vec::with_capacity(bytes.len() / record_size); + let mut offset = 0usize; + while offset + record_size <= bytes.len() { + let dim = u32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap()); + if dim != dimension { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Inconsistent vector dimensions", + )); } + let data_start = offset + 4; + let vec = bytes[data_start..data_start + dimension as usize] + .iter() + .map(|&x| x as f32) + .collect::>(); + vectors.push(VectorData::from_f32(vec)); + offset += record_size; } Ok(SiftDataset { dimension, vectors }) } pub async fn from_ivecs>(path: P) -> io::Result>> { - let file = File::open(path).await?; - let mut reader = BufReader::new(file); + let bytes = tokio::fs::read(path).await?; + if bytes.is_empty() { + return Ok(Vec::new()); + } let mut results = Vec::new(); - let mut buffer = [0u8; 4]; - - loop { - match reader.read_exact(&mut buffer).await { - Ok(_) => { - let k = u32::from_le_bytes(buffer); - let mut neighbors = vec![0u32; k as usize]; - - for item in neighbors.iter_mut().take(k as usize) { - reader.read_exact(&mut buffer).await?; - *item = u32::from_le_bytes(buffer); - } - results.push(neighbors); - } - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break, - Err(e) => return Err(e), + let mut offset = 0usize; + while offset + 4 <= bytes.len() { + let k = u32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap()) as usize; + offset += 4; + let record_bytes = k * 4; + if offset + record_bytes > bytes.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Invalid ivecs layout or truncated file", + )); } + + let mut neighbors = Vec::with_capacity(k); + for i in 0..k { + let s = offset + i * 4; + neighbors.push(u32::from_le_bytes(bytes[s..s + 4].try_into().unwrap())); + } + results.push(neighbors); + offset += record_bytes; } Ok(results) @@ -143,7 +135,6 @@ pub fn compute_recall(results: &[Vec], ground_truth: &[Vec], k: usize) let truth_set: HashSet = truth.iter().take(k).copied().collect(); let matches = result.iter().take(k).filter(|&&idx| truth_set.contains(&idx)).count(); - println!("Matches: {}", matches); total_matches += matches; } @@ -154,11 +145,23 @@ pub fn compute_recall(results: &[Vec], ground_truth: &[Vec], k: usize) pub async fn compute_recall_at_k( index_view: &IndexView, query: &SiftDataset, ground_truth: &[Vec], k: usize, ) { - let mut results = Vec::new(); - for q in query.vectors.iter() { - let res = index_view.search(q, k, 128).await; - results.push(res); - } + use futures::StreamExt; + + let query_len = query.vectors.len(); + let parallelism = env_usize("NYAS_SEARCH_CONCURRENCY").unwrap_or_else(|| { + std::thread::available_parallelism().map(|n| n.get()).unwrap_or(4).max(4) + }); + + let mut indexed_results: Vec<(usize, Vec)> = + futures::stream::iter(query.vectors.iter().enumerate()) + .map(|(i, q)| async move { (i, index_view.search(q, k, 128).await) }) + .buffer_unordered(parallelism) + .collect() + .await; + + indexed_results.sort_unstable_by_key(|(i, _)| *i); + let results: Vec> = indexed_results.into_iter().map(|(_, r)| r).collect(); + debug_assert_eq!(results.len(), query_len); let recall = compute_recall(&results, ground_truth, k); println!("Recall for k: {}: {:?}", k, recall); diff --git a/examples/src/sift1m_index.rs b/examples/src/sift1m_index.rs index 31b6636..1df6022 100644 --- a/examples/src/sift1m_index.rs +++ b/examples/src/sift1m_index.rs @@ -3,12 +3,17 @@ use std::time::{Duration, SystemTime}; use cpu_time::ProcessTime; use diskann::index_view::IndexView; +use futures::StreamExt; use system::vector_point::VectorPoint; use tokio::io; use crate::index_utils::SiftDataset; mod index_utils; +fn env_usize(name: &str) -> Option { + std::env::var(name).ok()?.parse::().ok().filter(|v| *v > 0) +} + #[tokio::main] async fn main() -> io::Result<()> { let base_folder = "examples/data/sift"; @@ -31,9 +36,27 @@ async fn main() -> io::Result<()> { let path = Path::new(index_name); if !path.exists() { - for (index, vector) in base.vectors.iter().enumerate() { - index_view.insert(&VectorPoint::new(index as u32, vector.clone())).await.unwrap(); - } + println!("Starting insertion of {} points...", base.vectors.len()); + + let insert_parallelism = env_usize("NYAS_INSERT_CONCURRENCY").unwrap_or_else(|| { + std::thread::available_parallelism().map(|n| n.get()).unwrap_or(4).max(4) + }); + println!("Insert concurrency: {}", insert_parallelism); + let index_view_ref = &index_view; + + futures::stream::iter(base.vectors.iter().enumerate()) + .for_each_concurrent(insert_parallelism, |(index, vector)| async move { + let point = VectorPoint::new(index as u32, vector.clone()); + index_view_ref.insert(&point).await.unwrap(); + }) + .await; + + println!("Finished insertion."); + + println!("Starting final streaming merge to disk..."); + let merge_start = SystemTime::now(); + index_view.streaming_merge().await.expect("Failed to perform final streaming merge"); + println!("Final streaming merge took: {:?}", merge_start.elapsed().unwrap()); } let index_cpu_time = start_cpu.elapsed(); diff --git a/nyas/diskann/Cargo.toml b/nyas/diskann/Cargo.toml index e6f2b10..5149659 100644 --- a/nyas/diskann/Cargo.toml +++ b/nyas/diskann/Cargo.toml @@ -21,3 +21,11 @@ serde.workspace = true rkyv.workspace = true tracing.workspace = true system = { path = "../system" } + +[[bench]] +name = "sift10k_index_10k" +harness = false + +[[bench]] +name = "sift10k_search_10k" +harness = false diff --git a/nyas/diskann/benches/sift10k_index_10k.rs b/nyas/diskann/benches/sift10k_index_10k.rs new file mode 100644 index 0000000..4f51807 --- /dev/null +++ b/nyas/diskann/benches/sift10k_index_10k.rs @@ -0,0 +1,65 @@ +use std::fs::File; +use std::io::{self, BufReader, Read}; +use std::path::{Path, PathBuf}; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; + +use diskann::index_view::IndexView; +use system::vector_data::VectorData; +use system::vector_point::VectorPoint; + +fn dataset_root() -> PathBuf { + Path::new(env!("CARGO_MANIFEST_DIR")).join("../../examples/data/siftsmall") +} + +fn read_fvecs(path: &Path) -> io::Result> { + let file = File::open(path)?; + let mut reader = BufReader::new(file); + let mut vectors = Vec::new(); + + loop { + let mut dim_bytes = [0u8; 4]; + match reader.read_exact(&mut dim_bytes) { + Ok(_) => {} + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break, + Err(e) => return Err(e), + } + + let dim = u32::from_le_bytes(dim_bytes) as usize; + let mut vec = vec![0f32; dim]; + let mut buf = [0u8; 4]; + for item in &mut vec { + reader.read_exact(&mut buf)?; + *item = f32::from_le_bytes(buf); + } + vectors.push(VectorData::from_f32(vec)); + } + + Ok(vectors) +} + +#[tokio::main] +async fn main() -> io::Result<()> { + let data_root = dataset_root(); + let base = read_fvecs(&data_root.join("siftsmall_base.fvecs"))?; + + let run_id = + SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or(Duration::from_secs(0)).as_millis(); + let index_name = format!("sift10k_index_only_{}", run_id); + let index_view = IndexView::new(&index_name).await?; + + let start = Instant::now(); + for (id, vector) in base.iter().enumerate() { + index_view + .insert(&VectorPoint::new(id as u32, vector.clone())) + .await + .map_err(io::Error::other)?; + } + let elapsed = start.elapsed(); + + println!("DiskANN SIFT10K Index Benchmark"); + println!("base vectors: {}", base.len()); + println!("indexing wall time: {:?}", elapsed); + + let _ = std::fs::remove_dir_all(&index_name); + Ok(()) +} diff --git a/nyas/diskann/benches/sift10k_search_10k.rs b/nyas/diskann/benches/sift10k_search_10k.rs new file mode 100644 index 0000000..dd6e2eb --- /dev/null +++ b/nyas/diskann/benches/sift10k_search_10k.rs @@ -0,0 +1,92 @@ +use std::fs::File; +use std::io::{self, BufReader, Read}; +use std::path::{Path, PathBuf}; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; + +use diskann::index_view::IndexView; +use system::vector_data::VectorData; +use system::vector_point::VectorPoint; + +fn dataset_root() -> PathBuf { + Path::new(env!("CARGO_MANIFEST_DIR")).join("../../examples/data/siftsmall") +} + +fn read_fvecs(path: &Path) -> io::Result> { + let file = File::open(path)?; + let mut reader = BufReader::new(file); + let mut vectors = Vec::new(); + + loop { + let mut dim_bytes = [0u8; 4]; + match reader.read_exact(&mut dim_bytes) { + Ok(_) => {} + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break, + Err(e) => return Err(e), + } + + let dim = u32::from_le_bytes(dim_bytes) as usize; + let mut vec = vec![0f32; dim]; + let mut buf = [0u8; 4]; + for item in &mut vec { + reader.read_exact(&mut buf)?; + *item = f32::from_le_bytes(buf); + } + vectors.push(VectorData::from_f32(vec)); + } + + Ok(vectors) +} + +fn percentile(sorted: &[Duration], p: f64) -> Duration { + if sorted.is_empty() { + return Duration::ZERO; + } + let idx = ((sorted.len() - 1) as f64 * p).round() as usize; + sorted[idx] +} + +#[tokio::main] +async fn main() -> io::Result<()> { + let data_root = dataset_root(); + let base = read_fvecs(&data_root.join("siftsmall_base.fvecs"))?; + let query = read_fvecs(&data_root.join("siftsmall_query.fvecs"))?; + + let run_id = + SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or(Duration::from_secs(0)).as_millis(); + let index_name = format!("sift10k_search_only_{}", run_id); + let index_view = IndexView::new(&index_name).await?; + + for (id, vector) in base.iter().enumerate() { + index_view + .insert(&VectorPoint::new(id as u32, vector.clone())) + .await + .map_err(io::Error::other)?; + } + + let search_count = 10_000usize; + let mut latencies = Vec::with_capacity(search_count); + let start = Instant::now(); + for i in 0..search_count { + let q = &query[i % query.len()]; + let t0 = Instant::now(); + let _ = index_view.search(q, 10, 128).await; + latencies.push(t0.elapsed()); + } + let total = start.elapsed(); + latencies.sort_unstable(); + let qps = + if total.as_secs_f64() > 0.0 { search_count as f64 / total.as_secs_f64() } else { 0.0 }; + + println!("DiskANN SIFT10K Search Benchmark"); + println!("base vectors: {}", base.len()); + println!("query vectors: {}", query.len()); + println!("search count: {}", search_count); + println!("search total wall time: {:?}", total); + println!("search QPS: {:.2}", qps); + println!("search p50 latency: {:?}", percentile(&latencies, 0.50)); + println!("search p95 latency: {:?}", percentile(&latencies, 0.95)); + println!("search p99 latency: {:?}", percentile(&latencies, 0.99)); + + let _ = std::fs::remove_dir_all(&index_name); + Ok(()) +} diff --git a/nyas/diskann/src/disk_index_storage.rs b/nyas/diskann/src/disk_index_storage.rs index 977ce85..2c5f28d 100644 --- a/nyas/diskann/src/disk_index_storage.rs +++ b/nyas/diskann/src/disk_index_storage.rs @@ -227,7 +227,7 @@ impl DiskIndexStorage { metric_type: mem_index.metric, point_size: aligned_vector_size as u64, vector_per_block: vector_per_block as u64, - start_node: mem_index.start_node.unwrap_or(0), + start_node: mem_index.start_node.read().unwrap().unwrap_or(0), }; self.meta = Some(disk_meta.clone()); @@ -424,9 +424,7 @@ impl DiskIndexStorage { Ok(meta) } - pub async fn search( - &self, query: &VectorData, l: usize, visited_map: &mut HashMap, - ) { + pub async fn search(&self, query: &VectorData, l: usize, visited_map: &mut HashMap) { if self.index_file.is_none() || self.meta.is_none() { return; }; @@ -437,7 +435,7 @@ impl DiskIndexStorage { } pub async fn greedy_search_ssd( - &self, visited_map: &mut HashMap, start_node_id: u32, query: &VectorData, + &self, visited_map: &mut HashMap, start_node_id: u32, query: &VectorData, l: usize, metric_type: MetricType, ) { let mut candidates = BinaryHeap::new(); @@ -476,8 +474,8 @@ impl DiskIndexStorage { continue; }; - let point = VectorPoint::new(node.id, vector.clone()); - visited_map.insert(current.point_id, point); + let current_distance = vector.distance(query, metric_type); + visited_map.insert(current.point_id, current_distance); if visited_map.len() >= l { break; diff --git a/nyas/diskann/src/in_disk_index.rs b/nyas/diskann/src/in_disk_index.rs index 981eeeb..38b8d4a 100644 --- a/nyas/diskann/src/in_disk_index.rs +++ b/nyas/diskann/src/in_disk_index.rs @@ -54,16 +54,18 @@ impl InDiskIndex { } pub async fn insert(&self, point: &VectorPoint) -> Result<(), String> { - let mut rw_temp = self.rw_temp_index.write().await; - // let current_time = SystemTime::now(); - rw_temp.insert(point); + { + let rw_temp = self.rw_temp_index.read().await; + rw_temp.insert(point); - // Check if we need to snapshot - if rw_temp.size() >= self.max_temp_size { - drop(rw_temp); - self.snapshot_temp_index().await?; + if rw_temp.size() < self.max_temp_size { + return Ok(()); + } } + // Only reach here if we need to snapshot + self.snapshot_temp_index().await?; + Ok(()) } @@ -72,7 +74,7 @@ impl InDiskIndex { } pub async fn search(&self, query: &VectorData, k: usize, l: usize) -> Vec { - let mut visited_map: HashMap = HashMap::new(); + let mut visited_map: HashMap = HashMap::new(); // Search LTI let lti = self.lti.read().await; @@ -81,19 +83,13 @@ impl InDiskIndex { // Search RW-TempIndex let rw_temp = self.rw_temp_index.read().await; rw_temp.greedy_search_for_lti(query, l, &mut visited_map); - - println!("RW temp size {:?}", rw_temp.points.len()); // Search RO-TempIndices let ro_temps = self.ro_temp_indices.read().await; for ro_temp in ro_temps.iter() { ro_temp.greedy_search_for_lti(query, l, &mut visited_map); - println!("RO temp size {:?}", ro_temp.points.len()); } - let mut result_with_dist: Vec<_> = visited_map - .iter() - .map(|(id, point)| (*id, point.distance_to_vector(query, self.metric))) - .collect(); + let mut result_with_dist: Vec<_> = visited_map.into_iter().collect(); if result_with_dist.len() > k { let _ = result_with_dist.select_nth_unstable_by(k - 1, |a, b| { @@ -102,18 +98,15 @@ impl InDiskIndex { result_with_dist.truncate(k); } - let mut top_k: Vec = result_with_dist.iter().map(|(id, _)| *id).collect(); - println!("Length of top_k: {}, result: {:?}", top_k.len(), result_with_dist.len()); - // println!("Top K {:?}", top_k); + result_with_dist + .sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Greater)); // Filter out deleted points - top_k.retain(|id| !self.delete_list.contains(id)); - - // Deduplicate and sort by distance - let unique_results: Vec<_> = - top_k.into_iter().collect::>().into_iter().collect(); - - unique_results.into_iter().take(k).collect() + result_with_dist + .into_iter() + .filter_map(|(id, _)| (!self.delete_list.contains(&id)).then_some(id)) + .take(k) + .collect() } #[allow(dead_code)] @@ -145,7 +138,7 @@ impl InDiskIndex { let snapshot = InMemIndex { graph: rw_temp.graph.clone(), points: rw_temp.points.clone(), - start_node: rw_temp.start_node, + start_node: std::sync::RwLock::new(*rw_temp.start_node.read().unwrap()), r: rw_temp.r, alpha: rw_temp.alpha, l_build: rw_temp.l_build, @@ -164,27 +157,87 @@ impl InDiskIndex { #[allow(dead_code)] pub async fn streaming_merge(&self) -> io::Result<()> { let deletes: Vec<_> = self.delete_list.iter().map(|e| *e).collect(); + let delete_set: std::collections::HashSet = deletes.iter().copied().collect(); + let mut rw_temp = self.rw_temp_index.write().await; let mut ro_temps = self.ro_temp_indices.write().await; - if ro_temps.is_empty() { + + if ro_temps.is_empty() && rw_temp.points.is_empty() { return Ok(()); } - //TODO: Improve it to avoid combine - let mut combined = InMemIndex::new(self.r, self.alpha, self.l_build, self.metric); + // Fast path: no deletes and only one staged index to flush. + if delete_set.is_empty() { + if ro_temps.is_empty() { + let mut lti = self.lti.write().await; + lti.insert(&rw_temp).await?; + drop(lti); + + *rw_temp = InMemIndex::new(self.r, self.alpha, self.l_build, self.metric); + self.delete_list.clear(); + return Ok(()); + } + + if ro_temps.len() == 1 && rw_temp.points.is_empty() { + let mut lti = self.lti.write().await; + lti.insert(&ro_temps[0]).await?; + drop(lti); + + ro_temps.clear(); + self.delete_list.clear(); + return Ok(()); + } + } + + println!("Starting streaming merge of {} RO indices and RW index...", ro_temps.len()); + + // Deduplicate by ID while collecting with "newest wins" semantics: + // RW (newest) first, then RO snapshots from newest to oldest. + let mut seen = std::collections::HashSet::new(); + let mut dedup_points: Vec = Vec::new(); - for ro_temp in ro_temps.iter_mut() { - ro_temp.delete(&deletes); - for entry in ro_temp.points.iter() { - combined.insert(entry.value()); + // Collect points from RW index and filter deleted IDs. + for point_ref in rw_temp.points.iter() { + let point = point_ref.value().clone(); + if !delete_set.contains(&point.id) && seen.insert(point.id) { + dedup_points.push(point); } } + // Collect points from RO indices from newest to oldest. + for ro_temp in ro_temps.iter().rev() { + for point_ref in ro_temp.points.iter() { + let point = point_ref.value().clone(); + if !delete_set.contains(&point.id) && seen.insert(point.id) { + dedup_points.push(point); + } + } + } + + if dedup_points.is_empty() { + rw_temp.points.clear(); + ro_temps.clear(); + self.delete_list.clear(); + return Ok(()); + } + + println!("Merging {} unique points...", dedup_points.len()); + + let combined = InMemIndex::new(self.r, self.alpha, self.l_build, self.metric); + + // Rebuild in parallel from deduplicated points. + use rayon::prelude::*; + dedup_points.into_par_iter().for_each(|point| { + combined.insert(&point); + }); + + println!("Persisting combined index to disk..."); let mut lti = self.lti.write().await; lti.insert(&combined).await?; - drop(lti); + println!("Streaming merge completed. Clearing temporary indices."); + *rw_temp = InMemIndex::new(self.r, self.alpha, self.l_build, self.metric); ro_temps.clear(); self.delete_list.clear(); Ok(()) @@ -194,6 +247,7 @@ impl InDiskIndex { #[cfg(test)] mod in_disk_index_test { use std::time::SystemTime; + use system::metric::MetricType; use system::vector_data::VectorData; use system::vector_point::VectorPoint; @@ -229,7 +283,7 @@ mod in_disk_index_test { let system = InDiskIndex::new("test_index_2", 32, 1.2, 50, 20, MetricType::L2) .await .expect("Expects In Disk Index creation"); - let mut in_mem_index = InMemIndex::new(32, 1.2, 50, MetricType::L2); + let in_mem_index = InMemIndex::new(32, 1.2, 50, MetricType::L2); for i in 0..50 { let vector = VectorData::from_f32(vec![i as f32, (i * 2) as f32, i as f32 / 2.0]); diff --git a/nyas/diskann/src/in_mem_index.rs b/nyas/diskann/src/in_mem_index.rs index b22229c..751a943 100644 --- a/nyas/diskann/src/in_mem_index.rs +++ b/nyas/diskann/src/in_mem_index.rs @@ -1,5 +1,5 @@ use std::collections::{BinaryHeap, HashMap, HashSet}; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, RwLock}; use dashmap::DashMap; use rayon::prelude::*; @@ -15,7 +15,7 @@ pub struct InMemIndex { /// Points stored in the index pub points: DashMap, /// Start node for navigation - pub start_node: Option, + pub start_node: RwLock>, /// Maximum out-degree pub r: usize, /// Alpha parameter for RNG property @@ -34,7 +34,7 @@ impl InMemIndex { InMemIndex { graph: DashMap::new(), points: DashMap::new(), - start_node: None, + start_node: RwLock::new(None), r, alpha, l_build, @@ -43,21 +43,27 @@ impl InMemIndex { } } - pub fn insert(&mut self, point: &VectorPoint) { + pub fn insert(&self, point: &VectorPoint) { debug_assert!(self.alpha > 1.0); self.points.insert(point.id, point.clone()); self.locks.insert(point.id, Arc::new(Mutex::new(()))); + + let start_node = *self.start_node.read().unwrap(); + //Let's initialize the graph with the first point - if self.start_node.is_none() { - self.start_node = Some(point.id); - self.graph.insert(point.id, Vec::new()); - return; //Just inserted first node, so no need to continue + if start_node.is_none() { + let mut start_lock = self.start_node.write().unwrap(); + if start_lock.is_none() { + *start_lock = Some(point.id); + self.graph.insert(point.id, Vec::new()); + return; + } } + let start_id = self.start_node.read().unwrap().unwrap(); + // Run greedy search to find candidates - debug_assert!(self.start_node.is_some()); - let (_, visited) = - self.greedy_search(self.start_node.unwrap(), &point.vector, 1, self.l_build); + let (_, visited) = self.greedy_search(start_id, &point.vector, 1, self.l_build); // Prune the visited point to get out-neighbors for the new point let out_neighbors = self.robust_prune(point.id, visited, self.alpha); @@ -170,9 +176,12 @@ impl InMemIndex { }; // Step 1: V ← (V + N_out(p)) - {p} - let mut seen = HashSet::new(); + let mut seen = HashSet::with_capacity(candidates.len() + self.r); seen.insert(point_id); candidates.retain(|id| *id != point_id); + for id in &candidates { + seen.insert(*id); + } if let Some(neighbors) = self.graph.get(&point_id) { for n in neighbors.value() { if seen.insert(*n) { @@ -181,34 +190,49 @@ impl InMemIndex { } } - let mut candidate_data: Vec = candidates - .into_par_iter() - .filter_map(|id| { - self.points.get(&id).map(|p| SearchCandidate { - point_id: id, - distance: point.distance(&p, self.metric), + // Optimization: Use sequential iterator for small candidate sets to avoid rayon overhead + let mut candidate_data: Vec = if candidates.len() < 500 { + candidates + .into_iter() + .filter_map(|id| { + self.points.get(&id).map(|p| SearchCandidate { + point_id: id, + distance: point.distance(&p, self.metric), + }) }) - }) - .collect(); + .collect() + } else { + candidates + .into_par_iter() + .filter_map(|id| { + self.points.get(&id).map(|p| SearchCandidate { + point_id: id, + distance: point.distance(&p, self.metric), + }) + }) + .collect() + }; candidate_data.sort_by(|a, b| b.cmp(a)); // Step 2 & 3: Prune candidates that don't satisfy the alpha-RNG condition let mut new_neighbors = Vec::with_capacity(self.r); - let mut new_neighbor_points: Vec = Vec::with_capacity(self.r); + let mut new_neighbor_points_ref = Vec::with_capacity(self.r); for candidate in candidate_data { let Some(point_c) = self.points.get(&candidate.point_id) else { continue; }; - let keep = !new_neighbor_points.iter().any(|point_n| { - alpha * point_n.distance(&point_c, self.metric) <= candidate.distance - }); + let keep = !new_neighbor_points_ref.iter().any( + |point_n_ref: &dashmap::mapref::one::Ref<'_, u32, VectorPoint>| { + alpha * point_n_ref.distance(&point_c, self.metric) <= candidate.distance + }, + ); if keep { new_neighbors.push(candidate.point_id); - new_neighbor_points.push(point_c.clone()); + new_neighbor_points_ref.push(point_c); } if new_neighbors.len() >= self.r { break; @@ -218,6 +242,7 @@ impl InMemIndex { new_neighbors } + #[allow(dead_code)] pub fn delete(&mut self, delete_list: &[u32]) { let delete_set: HashSet<_> = delete_list.iter().cloned().collect(); @@ -271,8 +296,11 @@ impl InMemIndex { self.locks.remove(point_id); // Update start_node if it was deleted - if self.start_node == Some(*point_id) { - self.start_node = self.graph.iter().next().map(|r| *r.key()); + { + let mut start_lock = self.start_node.write().unwrap(); + if *start_lock == Some(*point_id) { + *start_lock = self.graph.iter().next().map(|r| *r.key()); + } } } } @@ -280,8 +308,8 @@ impl InMemIndex { /// Search for k nearest neighbors #[allow(dead_code)] pub(crate) fn search(&self, query: &VectorData, k: usize, l: usize) -> Vec { - if let Some(start_id) = &self.start_node { - let (result, _) = self.greedy_search(*start_id, query, k, l); + if let Some(start_id) = *self.start_node.read().unwrap() { + let (result, _) = self.greedy_search(start_id, query, k, l); result } else { Vec::new() @@ -289,17 +317,28 @@ impl InMemIndex { } pub(crate) fn greedy_search_for_lti( - &self, query: &VectorData, l: usize, visited_map: &mut HashMap, + &self, query: &VectorData, l: usize, visited_map: &mut HashMap, ) { - if self.start_node.is_none() { - return; - } - let start_id = self.start_node.unwrap(); + let start_id = { + let lock = self.start_node.read().unwrap(); + if lock.is_none() { + return; + } + lock.unwrap() + }; let (_, visited_ids) = self.greedy_search(start_id, query, l, l); for id in visited_ids { if let Some(point) = self.points.get(&id) { - visited_map.insert(id, point.clone()); + let distance = point.distance_to_vector(query, self.metric); + visited_map + .entry(id) + .and_modify(|best| { + if distance < *best { + *best = distance; + } + }) + .or_insert(distance); } } } @@ -319,7 +358,7 @@ mod in_mem_index_test { #[test] fn test_in_mem_index() { - let mut index = InMemIndex::new(32, 1.2, 50, MetricType::L2); + let index = InMemIndex::new(32, 1.2, 50, MetricType::L2); for i in 0..100 { let vector = VectorData::from_f32(vec![i as f32, (i * 2) as f32, (i * 3) as f32]); @@ -348,7 +387,7 @@ mod in_mem_index_test { use rand::Rng; let mut rng = rand::rng(); - let mut index = InMemIndex::new($dim, $alpha, $l_build, $metric); + let index = InMemIndex::new($dim, $alpha, $l_build, $metric); for i in 0..$num_points { let vec: Vec = (0..$dim).map(|_| rng.random::()).collect(); diff --git a/nyas/diskann/src/index_view.rs b/nyas/diskann/src/index_view.rs index 6de4e27..36e3859 100644 --- a/nyas/diskann/src/index_view.rs +++ b/nyas/diskann/src/index_view.rs @@ -13,7 +13,7 @@ pub struct IndexView { impl IndexView { pub async fn new(index_name: &str) -> Result { let in_disk_index = - InDiskIndex::new(index_name, 32, 1.2, 128, 9999, MetricType::L2).await?; + InDiskIndex::new(index_name, 32, 1.2, 128, 500_000, MetricType::L2).await?; Ok(IndexView { in_disk_index }) } @@ -27,9 +27,10 @@ impl IndexView { } pub async fn search(&self, query: &VectorData, k: usize, l: usize) -> Vec { - //TODO: Implement streaming merge as background task - let result = self.in_disk_index.streaming_merge().await; - println!("Streaming merge result: {:?}", result); self.in_disk_index.search(query, k, l).await } + + pub async fn streaming_merge(&self) -> std::io::Result<()> { + self.in_disk_index.streaming_merge().await + } }