Skip to content

Commit 7274180

Browse files
committed
Added basic flow for diskann
1 parent 20cabd4 commit 7274180

13 files changed

Lines changed: 435 additions & 51 deletions

File tree

Cargo.lock

Lines changed: 39 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[workspace]
22
resolver = "3"
3-
members = [ "nyas/service", "nyas/storage", "nyas/vecd"]
3+
members = [ "nyas/diskann", "nyas/service", "nyas/storage", "nyas/vecd"]
44

55
[workspace.package]
66
authors = ["Manish Kumar <manish@neocraft.tech>"]
@@ -24,6 +24,10 @@ hnsw_rs = "0.3.2"
2424
uuid = { version = "1.3", features = ["v4"] }
2525
log = "0.4"
2626
env_logger = "0.10"
27+
memmap2 = "0.9.8"
28+
lru = "0.16.1"
29+
rand = "0.9.2"
30+
ordered-float = "5.1.0"
2731

2832
[workspace.dev-dependencies]
2933
cargo-nextest = "0.9.1"

bolt.sh

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,37 @@
11
#!/usr/bin/env bash
22
set -e
33

4+
# Ensure the pinned Rust toolchain and cargo-nextest are available.
#
# - Installs or switches to the pinned toolchain when `rustc` is missing or
#   reports a different version.
# - Uses an existing `rustup` when present; otherwise bootstraps via rustup-init.
# - Installs cargo-nextest when `cargo nextest --version` fails.
setup_rust(){
    # Single source of truth for the toolchain version used by this project.
    local required_version="1.90.0"
    local current_version=""
    local needs_install=0

    echo "[INFO] Checking Rust installation..."
    if command -v rustc >/dev/null 2>&1; then
        current_version=$(rustc --version | awk '{print $2}')
        echo "[INFO] Found Rust version $current_version"
        if [ "$current_version" != "$required_version" ]; then
            echo "[INFO] Updating Rust to $required_version..."
            needs_install=1
        else
            echo "[OK] Rust is already $required_version"
        fi
    else
        echo "[INFO] Rust not found. Installing Rust $required_version..."
        needs_install=1
    fi

    if [ "$needs_install" -eq 1 ]; then
        if command -v rustup >/dev/null 2>&1; then
            # rustup is already present: switch toolchains instead of re-running the installer.
            rustup toolchain install "$required_version"
            rustup default "$required_version"
        else
            # pipefail (scoped to this subshell) makes a failed download abort the
            # script instead of feeding an empty installer to `sh`.
            (
                set -o pipefail
                curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs \
                    | sh -s -- -y --default-toolchain "$required_version"
            )
        fi
    fi

    export PATH="$HOME/.cargo/bin:$PATH"
    rustc --version
    cargo --version

    echo "[INFO] Installing cargo-nextest if missing..."
    if ! cargo nextest --version >/dev/null 2>&1; then
        # --locked builds with the dependency versions nextest was released with.
        cargo install cargo-nextest --locked
    fi
    cargo nextest --version
}
30+
31+
# `setup` command: provision the tooling the other commands rely on.
# Currently this only delegates to setup_rust.
setup() { setup_rust; }
34+
435
check() {
536
echo "[INFO] Running cargo check..."
637
cargo check
@@ -28,9 +59,10 @@ test() {
2859
}
2960

3061
help() {
31-
echo "Usage: $0 [check|build|deploy|all|help]"
62+
echo "Usage: $0 [setup|check|build|deploy|all|help]"
3263
echo
3364
echo "Commands:"
65+
echo " setup - Install Rust and cargo-nextest"
3466
echo " check - Run cargo check, fmt, and clippy"
3567
echo " build - Only build the workspace (runs check first)"
3668
echo " test - Only run tests"
@@ -41,6 +73,9 @@ help() {
4173
main() {
4274
cmd="$1"
4375
case "$cmd" in
76+
setup)
77+
setup
78+
;;
4479
check)
4580
check
4681
;;
@@ -51,6 +86,7 @@ main() {
5186
test
5287
;;
5388
all)
89+
setup
5490
check
5591
build
5692
deploy

nyas/diskann/Cargo.toml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
[package]
2+
name = "diskann"
3+
authors.workspace = true
4+
edition.workspace = true
5+
homepage.workspace = true
6+
license.workspace = true
7+
readme.workspace = true
8+
repository.workspace = true
9+
rust-version.workspace = true
10+
version.workspace = true
11+
12+
[dependencies]
13+
lru.workspace = true
14+
memmap2.workspace = true
15+
rand.workspace = true
16+
tokio.workspace = true
17+
ordered-float.workspace = true

nyas/diskann/src/index.rs

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
use lru::LruCache;
2+
use memmap2::Mmap;
3+
use ordered_float::OrderedFloat;
4+
use std::num::NonZero;
5+
use tokio::fs::File;
6+
use tokio::sync::Mutex;
7+
use tokio::try_join;
8+
9+
type NodeId = u32;
10+
11+
#[derive(Debug)]
12+
pub struct DiskANNIndex {
13+
/// Memory-mapped vector storage
14+
vec_file: Mmap,
15+
16+
/// Memory-mapped adjacency lists
17+
adj_file: Mmap,
18+
19+
/// Cache for hot nodes: NodeId -> Vector + neighbors
20+
hot_cache: Mutex<LruCache<NodeId, CachedNode>>,
21+
22+
/// Entry points for search
23+
entry_points: Vec<NodeId>,
24+
25+
/// Number of neighbors per node
26+
max_neighbors: usize,
27+
28+
/// Dimension of vectors
29+
dim: usize,
30+
}
31+
32+
#[derive(Debug, Clone)]
33+
pub struct CachedNode {
34+
pub id: NodeId,
35+
pub vector: Vec<f32>,
36+
pub neighbors: Vec<NodeId>,
37+
}
38+
39+
impl DiskANNIndex {
40+
pub async fn load(
41+
vec_path: &str,
42+
adj_path: &str,
43+
dim: usize,
44+
max_neighbors: usize,
45+
) -> Result<Self, std::io::Error> {
46+
let (vec_file, adj_file) = try_join!(File::open(vec_path), File::open(adj_path))?;
47+
let vec_mmap = unsafe { Mmap::map(&vec_file)? };
48+
let adj_mmap = unsafe { Mmap::map(&adj_file)? };
49+
50+
Ok(Self {
51+
vec_file: vec_mmap,
52+
adj_file: adj_mmap,
53+
hot_cache: Mutex::new(LruCache::new(NonZero::new(1024).unwrap())),
54+
entry_points: vec![0],
55+
max_neighbors,
56+
dim,
57+
})
58+
}
59+
60+
async fn load_node(&self, node_id: NodeId) -> CachedNode {
61+
if let Some(node) = self.hot_cache.lock().await.get(&node_id) {
62+
return node.clone();
63+
}
64+
65+
let start = (node_id as usize) * self.dim * 4;
66+
let vec_bytes = &self.vec_file[start..start + self.dim * 4];
67+
let vector = vec_bytes
68+
.chunks_exact(4)
69+
.map(|b| f32::from_le_bytes(b.try_into().unwrap()))
70+
.collect();
71+
72+
let neighbors_start = (node_id as usize) * self.max_neighbors * 4;
73+
let neighbors_bytes =
74+
&self.adj_file[neighbors_start..neighbors_start + self.max_neighbors * 4];
75+
let neighbors: Vec<NodeId> = neighbors_bytes
76+
.chunks_exact(4)
77+
.map(|b| u32::from_le_bytes(b.try_into().unwrap()))
78+
.collect();
79+
80+
let cached = CachedNode {
81+
id: node_id,
82+
vector,
83+
neighbors,
84+
};
85+
self.hot_cache.lock().await.put(node_id, cached.clone());
86+
cached
87+
}
88+
89+
fn normalize(v: &[f32]) -> Vec<f32> {
90+
let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
91+
if norm == 0.0 {
92+
v.to_vec() // avoid division by zero
93+
} else {
94+
v.iter().map(|x| x / norm).collect()
95+
}
96+
}
97+
98+
pub fn distance(a: &[f32], b: &[f32]) -> f32 {
99+
assert_eq!(a.len(), b.len(), "Vectors must have the same length");
100+
101+
let a_norm = Self::normalize(a);
102+
let b_norm = Self::normalize(b);
103+
104+
a_norm
105+
.iter()
106+
.zip(b_norm.iter())
107+
.map(|(&x, &y)| (x - y).powi(2))
108+
.sum::<f32>()
109+
.sqrt()
110+
}
111+
112+
pub async fn search(
113+
&self,
114+
query: &[f32],
115+
k: usize,
116+
beam_width: usize,
117+
) -> Vec<(NodeId, OrderedFloat<f32>)> {
118+
use std::cmp::Reverse;
119+
use std::collections::{BinaryHeap, HashSet};
120+
121+
let mut candidates: BinaryHeap<Reverse<(OrderedFloat<f32>, NodeId)>> = BinaryHeap::new();
122+
let mut visited: HashSet<NodeId> = HashSet::new();
123+
let mut top_k: BinaryHeap<Reverse<(OrderedFloat<f32>, NodeId)>> = BinaryHeap::new();
124+
125+
for &ep in &self.entry_points {
126+
let node = self.load_node(ep).await;
127+
let dist = Self::distance(query, &node.vector);
128+
candidates.push(Reverse((OrderedFloat(dist), ep)));
129+
visited.insert(ep);
130+
}
131+
132+
while let Some(Reverse((dist, node_id))) = candidates.pop() {
133+
let node = self.load_node(node_id).await;
134+
135+
if top_k.len() < k {
136+
top_k.push(Reverse((dist, node_id)));
137+
} else if dist < top_k.peek().unwrap().0.0 {
138+
top_k.pop();
139+
top_k.push(Reverse((dist, node_id)));
140+
}
141+
142+
for &nbr in &node.neighbors {
143+
if !visited.contains(&nbr) {
144+
visited.insert(nbr);
145+
let nbr_node = self.load_node(nbr).await;
146+
let nbr_dist = Self::distance(query, &nbr_node.vector);
147+
candidates.push(Reverse((OrderedFloat(nbr_dist), nbr)));
148+
if candidates.len() > beam_width {
149+
candidates.pop();
150+
}
151+
}
152+
}
153+
}
154+
155+
let mut result: Vec<_> = top_k.into_iter().map(|Reverse((d, id))| (id, d)).collect();
156+
result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
157+
result
158+
}
159+
}

nyas/diskann/src/index_builder.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
use tokio::fs::File;
2+
use tokio::io::{AsyncWriteExt, BufWriter};
3+
4+
pub type NodeId = u32;
5+
6+
pub async fn build_index(
7+
vectors: &Vec<&Vec<f32>>,
8+
max_neighbors: usize,
9+
vectors_path: &str,
10+
graph_path: &str,
11+
) -> std::io::Result<()> {
12+
let vec_file = File::create(vectors_path).await?;
13+
let mut vec_writer = BufWriter::new(vec_file);
14+
15+
for vec in vectors.iter() {
16+
for &f in vec.iter() {
17+
vec_writer.write_all(&f.to_le_bytes()).await?;
18+
}
19+
}
20+
vec_writer.flush().await?;
21+
22+
let graph_file = File::create(graph_path).await?;
23+
let mut graph_writer = BufWriter::new(graph_file);
24+
25+
for (i, _) in vectors.iter().enumerate() {
26+
let mut neighbors: Vec<NodeId> = (0..vectors.len() as u32)
27+
.filter(|&x| x != i as u32)
28+
.collect();
29+
neighbors.truncate(max_neighbors);
30+
31+
for &nid in &neighbors {
32+
graph_writer.write_all(&nid.to_le_bytes()).await?;
33+
}
34+
}
35+
graph_writer.flush().await?;
36+
37+
Ok(())
38+
}

nyas/diskann/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pub mod index;
2+
pub mod index_builder;

0 commit comments

Comments
 (0)