From 187a7d3d368313c8e6d61ca6c46d2563291efe5f Mon Sep 17 00:00:00 2001 From: dynamder Date: Thu, 4 Dec 2025 23:08:39 +0800 Subject: [PATCH 01/22] scaffolding retrieve algorithm. --- src/memory.rs | 1 + src/memory/algo.rs | 1 + src/memory/algo/retrieve.rs | 14 ++++++++++++++ src/memory/algo/retrieve/association.rs | 13 +++++++++++++ src/memory/algo/retrieve/deep_thought.rs | 12 ++++++++++++ src/memory/algo/retrieve/short_only.rs | 12 ++++++++++++ src/memory/algo/retrieve/similarity.rs | 13 +++++++++++++ 7 files changed, 66 insertions(+) create mode 100644 src/memory/algo.rs create mode 100644 src/memory/algo/retrieve.rs create mode 100644 src/memory/algo/retrieve/association.rs create mode 100644 src/memory/algo/retrieve/deep_thought.rs create mode 100644 src/memory/algo/retrieve/short_only.rs create mode 100644 src/memory/algo/retrieve/similarity.rs diff --git a/src/memory.rs b/src/memory.rs index 6df01bd..9787e9e 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -1,3 +1,4 @@ +pub mod algo; pub mod embedding; pub mod memory_cluster; pub mod memory_links; diff --git a/src/memory/algo.rs b/src/memory/algo.rs new file mode 100644 index 0000000..447f096 --- /dev/null +++ b/src/memory/algo.rs @@ -0,0 +1 @@ +pub mod retrieve; \ No newline at end of file diff --git a/src/memory/algo/retrieve.rs b/src/memory/algo/retrieve.rs new file mode 100644 index 0000000..13b6bcf --- /dev/null +++ b/src/memory/algo/retrieve.rs @@ -0,0 +1,14 @@ +use crate::memory::embedding::MemoryEmbedding; +pub mod association; +pub mod deep_thought; +pub mod short_only; +pub mod similarity; + +pub trait RetrStrategy { + type RetrRequest; //接受的查询参数类型 + fn retrieve(&self, request: Self::RetrRequest) -> Vec; //TODO:返回类型还没想好,暂定Vec,或许也可以考虑返回迭代器,看具体场景 +} + + + + diff --git a/src/memory/algo/retrieve/association.rs b/src/memory/algo/retrieve/association.rs new file mode 100644 index 0000000..842efd3 --- /dev/null +++ b/src/memory/algo/retrieve/association.rs @@ -0,0 +1,13 @@ +use super::RetrStrategy; + +//用PPR变种算法进行联想 +pub struct RetrAssociation { + max_results: usize, +} +pub struct AssociationRequest {} +impl RetrStrategy for RetrAssociation { + type RetrRequest = AssociationRequest; + fn retrieve(&self, request: Self::RetrRequest) -> Vec { + todo!() + } +} diff --git a/src/memory/algo/retrieve/deep_thought.rs b/src/memory/algo/retrieve/deep_thought.rs new file mode 100644 index 0000000..efb3630 --- /dev/null +++ b/src/memory/algo/retrieve/deep_thought.rs @@ -0,0 +1,12 @@ +use super::RetrStrategy; +// 采用 LLM进行的Plan-on-Graph +pub struct RetrDeepThought { + max_depth: usize, +} +pub struct DeepThoughtRequest {} +impl RetrStrategy for RetrDeepThought { + type RetrRequest = DeepThoughtRequest; + fn retrieve(&self, request: Self::RetrRequest) -> Vec { + todo!() + } +} diff --git a/src/memory/algo/retrieve/short_only.rs b/src/memory/algo/retrieve/short_only.rs new file mode 100644 index 0000000..c89bb8f --- /dev/null +++ b/src/memory/algo/retrieve/short_only.rs @@ -0,0 +1,12 @@ +//仅提取短期记忆策略,即仅提取滑动窗口 +use super::RetrStrategy; +pub struct RetrShortOnly { + clipping_length: Option, + include_summary: bool, +} +impl RetrStrategy for RetrShortOnly { + type RetrRequest = (); + fn retrieve(&self, _request: Self::RetrRequest) -> Vec { + todo!() + } +} diff --git a/src/memory/algo/retrieve/similarity.rs b/src/memory/algo/retrieve/similarity.rs new file mode 100644 index 0000000..539d8a2 --- /dev/null +++ b/src/memory/algo/retrieve/similarity.rs @@ -0,0 +1,13 @@ +//仅提取相似记忆策略,即仅提取相似度大于阈值的记忆片段 +use super::RetrStrategy; +pub struct RetrSimilarity 
{ + similarity_threshold: f64, + max_results: usize, +} +pub struct SimilarityRequest {} +impl RetrStrategy for RetrSimilarity { + type RetrRequest = SimilarityRequest; + fn retrieve(&self, request: Self::RetrRequest) -> Vec { + todo!() + } +} From 8f75c664d7d45837353d3d408b8d37662acbb0e1 Mon Sep 17 00:00:00 2001 From: dynamder Date: Thu, 4 Dec 2025 23:27:22 +0800 Subject: [PATCH 02/22] add working memory --- src/memory/working_memory.rs | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 src/memory/working_memory.rs diff --git a/src/memory/working_memory.rs b/src/memory/working_memory.rs new file mode 100644 index 0000000..30fb4ef --- /dev/null +++ b/src/memory/working_memory.rs @@ -0,0 +1,7 @@ +use crate::memory::memory_cluster::MemoryCluster; + +//代表工作记忆,应当包含记忆子图,短期记忆(滑动窗口),记忆的提取记录等。 +// 占位,后续逐渐增加内容 +pub struct WorkingMemory { + cluster: MemoryCluster, +} From 729cda8850d0e550480a67f2246ab28816d60142 Mon Sep 17 00:00:00 2001 From: dynamder Date: Thu, 4 Dec 2025 23:27:32 +0800 Subject: [PATCH 03/22] detailed the scaffold --- src/memory.rs | 1 + src/memory/algo/retrieve.rs | 4 ---- src/memory/algo/retrieve/association.rs | 8 +++++++- src/memory/algo/retrieve/deep_thought.rs | 8 +++++++- src/memory/algo/retrieve/short_only.rs | 10 ++++++++-- src/memory/algo/retrieve/similarity.rs | 6 +++++- 6 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/memory.rs b/src/memory.rs index 9787e9e..93484ef 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -3,3 +3,4 @@ pub mod embedding; pub mod memory_cluster; pub mod memory_links; pub mod memory_note; +pub mod working_memory; diff --git a/src/memory/algo/retrieve.rs b/src/memory/algo/retrieve.rs index 13b6bcf..8bd1527 100644 --- a/src/memory/algo/retrieve.rs +++ b/src/memory/algo/retrieve.rs @@ -8,7 +8,3 @@ pub trait RetrStrategy { type RetrRequest; //接受的查询参数类型 fn retrieve(&self, request: Self::RetrRequest) -> Vec; //TODO:返回类型还没想好,暂定Vec,或许也可以考虑返回迭代器,看具体场景 } - - - - diff --git a/src/memory/algo/retrieve/association.rs b/src/memory/algo/retrieve/association.rs index 842efd3..43a3e67 100644 --- a/src/memory/algo/retrieve/association.rs +++ b/src/memory/algo/retrieve/association.rs @@ -1,10 +1,16 @@ +use std::sync::Arc; + +use crate::memory::working_memory::WorkingMemory; + use super::RetrStrategy; //用PPR变种算法进行联想 pub struct RetrAssociation { max_results: usize, } -pub struct AssociationRequest {} +pub struct AssociationRequest { + working_mem: Arc +} impl RetrStrategy for RetrAssociation { type RetrRequest = AssociationRequest; fn retrieve(&self, request: Self::RetrRequest) -> Vec { diff --git a/src/memory/algo/retrieve/deep_thought.rs b/src/memory/algo/retrieve/deep_thought.rs index efb3630..daec33c 100644 --- a/src/memory/algo/retrieve/deep_thought.rs +++ b/src/memory/algo/retrieve/deep_thought.rs @@ -1,9 +1,15 @@ +use std::sync::Arc; + +use crate::memory::working_memory::WorkingMemory; + use super::RetrStrategy; // 采用 LLM进行的Plan-on-Graph pub struct RetrDeepThought { max_depth: usize, } -pub struct DeepThoughtRequest {} +pub struct DeepThoughtRequest { + working_mem: Arc, +} impl RetrStrategy for RetrDeepThought { type RetrRequest = DeepThoughtRequest; fn retrieve(&self, request: Self::RetrRequest) -> Vec { diff --git a/src/memory/algo/retrieve/short_only.rs b/src/memory/algo/retrieve/short_only.rs index c89bb8f..5cb4255 100644 --- a/src/memory/algo/retrieve/short_only.rs +++ b/src/memory/algo/retrieve/short_only.rs @@ -1,12 +1,18 @@ +use crate::memory::working_memory::WorkingMemory; +use std::sync::Arc; + //仅提取短期记忆策略,即仅提取滑动窗口 use 
super::RetrStrategy; pub struct RetrShortOnly { clipping_length: Option, include_summary: bool, } +pub struct ShortOnlyRequest { + working_mem: Arc, //因为检索算法很可能需要并发执行,使用Arc而非引用确保可以Send +} impl RetrStrategy for RetrShortOnly { - type RetrRequest = (); - fn retrieve(&self, _request: Self::RetrRequest) -> Vec { + type RetrRequest = ShortOnlyRequest; + fn retrieve(&self, request: Self::RetrRequest) -> Vec { todo!() } } diff --git a/src/memory/algo/retrieve/similarity.rs b/src/memory/algo/retrieve/similarity.rs index 539d8a2..7697a19 100644 --- a/src/memory/algo/retrieve/similarity.rs +++ b/src/memory/algo/retrieve/similarity.rs @@ -1,10 +1,14 @@ //仅提取相似记忆策略,即仅提取相似度大于阈值的记忆片段 use super::RetrStrategy; +use crate::memory::working_memory::WorkingMemory; +use std::sync::Arc; pub struct RetrSimilarity { similarity_threshold: f64, max_results: usize, } -pub struct SimilarityRequest {} +pub struct SimilarityRequest { + working_mem: Arc, +} impl RetrStrategy for RetrSimilarity { type RetrRequest = SimilarityRequest; fn retrieve(&self, request: Self::RetrRequest) -> Vec { From c307d7db26bfe817db2a539390f8b149c6139718 Mon Sep 17 00:00:00 2001 From: dynamder Date: Sun, 14 Dec 2025 22:54:50 +0800 Subject: [PATCH 04/22] rustfmt --- src/memory/algo/retrieve/association.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memory/algo/retrieve/association.rs b/src/memory/algo/retrieve/association.rs index 43a3e67..b519d83 100644 --- a/src/memory/algo/retrieve/association.rs +++ b/src/memory/algo/retrieve/association.rs @@ -9,7 +9,7 @@ pub struct RetrAssociation { max_results: usize, } pub struct AssociationRequest { - working_mem: Arc + working_mem: Arc, } impl RetrStrategy for RetrAssociation { type RetrRequest = AssociationRequest; From 7c417431b49607a69313b1a6cf2f5e00d3835958 Mon Sep 17 00:00:00 2001 From: dynamder Date: Sun, 14 Dec 2025 22:58:10 +0800 Subject: [PATCH 05/22] add embedding vec placeholder --- src/memory/embedding.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/memory/embedding.rs b/src/memory/embedding.rs index 096d2b0..a6e1d5a 100644 --- a/src/memory/embedding.rs +++ b/src/memory/embedding.rs @@ -2,6 +2,10 @@ use thiserror::Error; use crate::memory::memory_note::EmbedMemoryNote; +pub struct EmbeddingVector { + // Placeholder for embedding vector +} + pub trait Embeddable { fn embed(&self, model: &EmbeddingModel) -> Result; fn embed_vec(&self, model: &EmbeddingModel) -> Result; @@ -19,8 +23,8 @@ pub enum EmbeddingError { pub struct EmbeddingModel { // Placeholder for embedding model wrapper } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct MemoryEmbedding { //Placeholder for embedding holder } From 341c99d48eb6e0e79e07de7bfa84fff223eee57f Mon Sep 17 00:00:00 2001 From: dynamder Date: Sun, 14 Dec 2025 23:24:56 +0800 Subject: [PATCH 06/22] add placeholder for basic embedding calculation, detail the error type. 
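This patch splits the placeholder EmbeddingError into EmbeddingGenError (embedding
generation) and EmbeddingCalcError (vector math), and stubs out euclidean/cosine/
manhattan methods on EmbeddingVector and MemoryEmbedding. The bodies are still
todo!() placeholders; a rough sketch of the intended call site (query_vec and
candidate_vec are illustrative names, not part of this patch):

    match query_vec.cosine_similarity(&candidate_vec) {
        Ok(_similarity) => {
            // use the score, e.g. for threshold-based retrieval
        }
        Err(EmbeddingCalcError::ShapeMismatch) => {
            // the two vectors have different dimensions
        }
        Err(_other) => {
            // invalid vector or NaN/Inf values
        }
    }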
--- src/memory/embedding.rs | 67 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 62 insertions(+), 5 deletions(-) diff --git a/src/memory/embedding.rs b/src/memory/embedding.rs index a6e1d5a..a3c1a50 100644 --- a/src/memory/embedding.rs +++ b/src/memory/embedding.rs @@ -5,21 +5,44 @@ use crate::memory::memory_note::EmbedMemoryNote; pub struct EmbeddingVector { // Placeholder for embedding vector } +impl EmbeddingVector { + pub fn euclidean_distance(&self, other: &EmbeddingVector) -> Result { + todo!("Euclidean distance") + } + pub fn cosine_similarity(&self, other: &EmbeddingVector) -> Result { + todo!("Cosine similarity") + } + pub fn manhattan_distance(&self, other: &EmbeddingVector) -> Result { + todo!("Manhattan distance") + } +} pub trait Embeddable { - fn embed(&self, model: &EmbeddingModel) -> Result; - fn embed_vec(&self, model: &EmbeddingModel) -> Result; + fn embed(&self, model: &EmbeddingModel) -> Result; + fn embed_vec(&self, model: &EmbeddingModel) -> Result; } +type EmbedCalcResult = Result; +type EmbeddingGenResult = Result; -//Only a placeholder for now #[derive(Debug, Error)] -pub enum EmbeddingError { - #[error("Invalid input")] +pub enum EmbeddingGenError { + #[error("Invalid input")] //缺失了某些必要字段 InvalidInput, #[error("Embedding failed")] EmbeddingFailed, } +//Only a placeholder for now +#[derive(Debug, Error)] +pub enum EmbeddingCalcError { + #[error("Invalid vec")] //缺失了某些必要字段 + InvalidVec, + #[error("Shape mismatch")] //维度不匹配 + ShapeMismatch, + #[error("Invalid number value")] //数值无效,例如NaN,Inf等 + InvalidNumValue, +} + pub struct EmbeddingModel { // Placeholder for embedding model wrapper } @@ -28,3 +51,37 @@ pub struct EmbeddingModel { pub struct MemoryEmbedding { //Placeholder for embedding holder } +impl MemoryEmbedding { + pub fn euclidean_distance( + &self, + other: &MemoryEmbedding, + hyperparams: VecBlendHyperParams, + ) -> Result { + todo!("Euclidean distance") + } + pub fn cosine_similarity( + &self, + other: &MemoryEmbedding, + hyperparams: VecBlendHyperParams, + ) -> Result { + todo!("Cosine similarity") + } + pub fn manhattan_distance( + &self, + other: &MemoryEmbedding, + hyperparams: VecBlendHyperParams, + ) -> Result { + todo!("Manhattan distance") + } +} +#[derive(Debug, Clone, Copy)] +pub struct VecBlendHyperParams { + // Placeholder for vector blending hyperparameters +} +impl Default for VecBlendHyperParams { + fn default() -> Self { + VecBlendHyperParams { + // Placeholder for default values + } + } +} From 9efac24caf387a10c783e8ae96b8da369a048922 Mon Sep 17 00:00:00 2001 From: dynamder Date: Mon, 15 Dec 2025 00:44:27 +0800 Subject: [PATCH 07/22] add cache placeholder --- src/cache.rs | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/cache.rs diff --git a/src/cache.rs b/src/cache.rs new file mode 100644 index 0000000..e69de29 From 828c1d6c64b320e1c3a56af827296a3155ac5c16 Mon Sep 17 00:00:00 2001 From: dynamder Date: Mon, 15 Dec 2025 00:44:43 +0800 Subject: [PATCH 08/22] add foyer dependency --- Cargo.lock | 457 +++++++++++++++++++++++++++++++++++++++++++++++++++-- Cargo.toml | 1 + src/lib.rs | 1 + 3 files changed, 450 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 010a3a3..97e7ae3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -167,6 +167,12 @@ version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223" +[[package]] +name = "arc-swap" +version = "1.7.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "arg_enum_proc_macro" version = "0.3.4" @@ -275,7 +281,7 @@ checksum = "fd45deb3dbe5da5cdb8d6a670a7736d735ba65b455328440f236dfb113727a3d" dependencies = [ "Inflector", "async-graphql-parser", - "darling", + "darling 0.20.11", "proc-macro-crate", "proc-macro2", "quote", @@ -896,6 +902,12 @@ dependencies = [ "libloading", ] +[[package]] +name = "cmsketch" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7ee2cfacbd29706479902b06d75ad8f1362900836aa32799eabc7e004bfd854" + [[package]] name = "color_quant" version = "1.1.0" @@ -971,6 +983,17 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "core_affinity" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a034b3a7b624016c6e13f5df875747cc25f884156aad2abd12b6c46797971342" +dependencies = [ + "libc", + "num_cpus", + "winapi", +] + [[package]] name = "cpufeatures" version = "0.2.17" @@ -1053,14 +1076,38 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "darling" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" +dependencies = [ + "darling_core 0.14.4", + "darling_macro 0.14.4", +] + [[package]] name = "darling" version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ - "darling_core", - "darling_macro", + "darling_core 0.20.11", + "darling_macro 0.20.11", +] + +[[package]] +name = "darling_core" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.10.0", + "syn 1.0.109", ] [[package]] @@ -1073,17 +1120,28 @@ dependencies = [ "ident_case", "proc-macro2", "quote", - "strsim", + "strsim 0.11.1", "syn 2.0.104", ] +[[package]] +name = "darling_macro" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" +dependencies = [ + "darling_core 0.14.4", + "quote", + "syn 1.0.109", +] + [[package]] name = "darling_macro" version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ - "darling_core", + "darling_core 0.20.11", "quote", "syn 2.0.104", ] @@ -1151,7 +1209,7 @@ version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" dependencies = [ - "darling", + "darling 0.20.11", "proc-macro2", "quote", "syn 2.0.104", @@ -1259,6 +1317,12 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" +[[package]] +name = "downcast-rs" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2" + [[package]] name = "dtoa" 
version = "1.0.10" @@ -1417,6 +1481,16 @@ dependencies = [ "tempfile", ] +[[package]] +name = "fastant" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e825441bfb2d831c47c97d05821552db8832479f44c571b97fededbf0099c07" +dependencies = [ + "small_ctor", + "web-time", +] + [[package]] name = "fastembed" version = "5.2.0" @@ -1487,6 +1561,18 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8bf7cc16383c4b8d58b9905a8509f02926ce3058053c056376248d958c9df1e8" +[[package]] +name = "flume" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" +dependencies = [ + "futures-core", + "futures-sink", + "nanorand", + "spin", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1499,6 +1585,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + [[package]] name = "foreign-types" version = "0.3.2" @@ -1529,12 +1621,131 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d8866fac38f53fc87fa3ae1b09ddd723e0482f8fa74323518b4c59df2c55a00a" +[[package]] +name = "foyer" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a31f699ce88ac9a53677ca0b1f7a3a902bf3bfae0579e16e86ddf61dee569c0" +dependencies = [ + "anyhow", + "equivalent", + "foyer-common", + "foyer-memory", + "foyer-storage", + "futures-util", + "madsim-tokio", + "mixtrics", + "pin-project", + "serde", + "tokio", + "tracing", +] + +[[package]] +name = "foyer-common" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9ea2c266c9d93ea37c3960f2d0bb625981eefd38120eb06542804bf8169a187" +dependencies = [ + "anyhow", + "bincode", + "bytes", + "cfg-if", + "itertools 0.14.0", + "madsim-tokio", + "mixtrics", + "parking_lot 0.12.4", + "pin-project", + "serde", + "tokio", + "twox-hash", +] + +[[package]] +name = "foyer-intrusive-collections" +version = "0.10.0-dev" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e4fee46bea69e0596130e3210e65d3424e0ac1e6df3bde6636304bdf1ca4a3b" +dependencies = [ + "memoffset", +] + +[[package]] +name = "foyer-memory" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09941796e5f8301e82e81e0c9e7514a8443524a461f9a2177500c7525aa73723" +dependencies = [ + "anyhow", + "arc-swap", + "bitflags 2.9.1", + "cmsketch", + "equivalent", + "foyer-common", + "foyer-intrusive-collections", + "futures-util", + "hashbrown 0.16.1", + "itertools 0.14.0", + "madsim-tokio", + "mixtrics", + "parking_lot 0.12.4", + "paste", + "pin-project", + "serde", + "tokio", + "tracing", +] + +[[package]] +name = "foyer-storage" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75fc3db8b685c3eb8b13f05847436933f1f68e91b68510c615d90e9a69541c84" +dependencies = [ + "allocator-api2", + "anyhow", + "bytes", + "core_affinity", + "equivalent", + "fastant", + "flume", + "foyer-common", + "foyer-memory", + "fs4", + "futures-core", + "futures-util", + "hashbrown 0.16.1", + "io-uring", + "itertools 0.14.0", + 
"libc", + "lz4", + "madsim-tokio", + "parking_lot 0.12.4", + "pin-project", + "rand 0.9.1", + "serde", + "tokio", + "tracing", + "twox-hash", + "zstd", +] + [[package]] name = "fragile" version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28dd6caf6059519a65843af8fe2a3ae298b14b80179855aeb4adc2c1934ee619" +[[package]] +name = "fs4" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8640e34b88f7652208ce9e88b1a37a2ae95227d84abec377ccd3c5cfeb141ed4" +dependencies = [ + "rustix", + "windows-sys 0.59.0", +] + [[package]] name = "fst" version = "0.4.7" @@ -1929,7 +2140,18 @@ checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5" dependencies = [ "allocator-api2", "equivalent", - "foldhash", + "foldhash 0.1.5", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash 0.2.0", ] [[package]] @@ -2390,6 +2612,17 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "io-uring" +version = "0.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdd7bddefd0a8833b88a4b68f90dae22c7450d11b354198baee3874fd811b344" +dependencies = [ + "bitflags 2.9.1", + "cfg-if", + "libc", +] + [[package]] name = "ipnet" version = "2.11.0" @@ -2672,6 +2905,15 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" +[[package]] +name = "lz4" +version = "1.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a20b523e860d03443e98350ceaac5e71c6ba89aea7d960769ec3ce37f4de5af4" +dependencies = [ + "lz4-sys", +] + [[package]] name = "lz4-sys" version = "1.11.1+lz4-1.10.0" @@ -2704,6 +2946,61 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30" +[[package]] +name = "madsim" +version = "0.2.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18351aac4194337d6ea9ffbd25b3d1540ecc0754142af1bff5ba7392d1f6f771" +dependencies = [ + "ahash 0.8.12", + "async-channel", + "async-stream", + "async-task", + "bincode", + "bytes", + "downcast-rs", + "errno", + "futures-util", + "lazy_static", + "libc", + "madsim-macros", + "naive-timer", + "panic-message", + "rand 0.8.5", + "rand_xoshiro", + "rustversion", + "serde", + "spin", + "tokio", + "tokio-util", + "toml 0.9.5", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "madsim-macros" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3d248e97b1a48826a12c3828d921e8548e714394bf17274dd0a93910dc946e1" +dependencies = [ + "darling 0.14.4", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "madsim-tokio" +version = "0.2.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d3eb2acc57c82d21d699119b859e2df70a91dbdb84734885a1e72be83bdecb5" +dependencies = [ + "madsim", + "spin", + "tokio", +] + [[package]] name = "maplit" version = "1.0.2" @@ -2774,6 +3071,15 @@ version = "2.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" +[[package]] +name = "memoffset" +version = 
"0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + [[package]] name = "miette" version = "5.10.0" @@ -2840,6 +3146,16 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "mixtrics" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb252c728b9d77c6ef9103f0c81524fa0a3d3b161d0a936295d7fbeff6e04c11" +dependencies = [ + "itertools 0.14.0", + "parking_lot 0.12.4", +] + [[package]] name = "mockall" version = "0.13.1" @@ -2904,6 +3220,12 @@ dependencies = [ "version_check", ] +[[package]] +name = "naive-timer" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "034a0ad7deebf0c2abcf2435950a6666c3c15ea9d8fad0c0f48efa8a7f843fed" + [[package]] name = "nalgebra" version = "0.34.0" @@ -2966,6 +3288,15 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "nanorand" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" +dependencies = [ + "getrandom 0.2.16", +] + [[package]] name = "native-tls" version = "0.2.14" @@ -3076,6 +3407,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys 0.60.2", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -3303,6 +3643,12 @@ dependencies = [ "ureq 3.1.2", ] +[[package]] +name = "panic-message" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384e52fd8fbd4cbe3c317e8216260c21a0f9134de108cea8a4dd4e7e152c472d" + [[package]] name = "parking" version = "2.2.1" @@ -3940,6 +4286,15 @@ dependencies = [ "getrandom 0.3.3", ] +[[package]] +name = "rand_xoshiro" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa" +dependencies = [ + "rand_core 0.6.4", +] + [[package]] name = "rav1e" version = "0.7.1" @@ -4747,7 +5102,7 @@ version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de90945e6565ce0d9a25098082ed4ee4002e047cb59892c318d66821e14bb30f" dependencies = [ - "darling", + "darling 0.20.11", "proc-macro2", "quote", "syn 2.0.104", @@ -4775,6 +5130,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shlex" version = "1.3.0" @@ -4848,6 +5212,12 @@ version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04dc19736151f35336d325007ac991178d504a119863a2fcb3758cdb5e52c50d" +[[package]] +name = "small_ctor" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88414a5ca1f85d82cc34471e975f0f74f6aa54c40f062efa42c0080e7f763f81" + [[package]] name = "smallvec" version = "1.15.1" @@ -4907,6 +5277,7 @@ dependencies = [ "dotenvy", "fastembed", "formatx", + "foyer", "log", "mockall", "nalgebra", @@ -4946,6 +5317,9 @@ name = "spin" version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] [[package]] name = "spm_precompiled" @@ -5027,6 +5401,12 @@ dependencies = [ "quote", ] +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + [[package]] name = "strsim" version = "0.11.1" @@ -5177,7 +5557,7 @@ dependencies = [ "sha2", "snap", "storekey", - "strsim", + "strsim 0.11.1", "subtle", "sysinfo", "tempfile", @@ -5798,6 +6178,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" +dependencies = [ + "nu-ansi-term", + "sharded-slab", + "smallvec 1.15.1", + "thread_local", + "tracing-core", + "tracing-log", ] [[package]] @@ -5838,6 +6244,15 @@ dependencies = [ "utf-8", ] +[[package]] +name = "twox-hash" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ea3136b675547379c4bd395ca6b938e5ad3c3d20fad76e7fe85f9e0d011419c" +dependencies = [ + "rand 0.9.1", +] + [[package]] name = "typenum" version = "1.18.0" @@ -6046,6 +6461,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "vart" version = "0.8.1" @@ -6766,6 +7187,24 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + [[package]] name = "zstd-sys" version = "2.0.15+zstd.1.5.7" diff --git a/Cargo.toml b/Cargo.toml index bc21d69..c54f17b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ dotenvy = "0.15.7" tokio-util = "0.7.16" toml = "0.9.5" thiserror = "2.0.17" +foyer = { version = "0.21.1", features = ["serde"] } [profile.release] lto = true diff --git a/src/lib.rs b/src/lib.rs index bf4bf15..ab818fb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ //! This lib crate is inspired from A-mem, an agentic memory system, and HippoRAG. 
pub mod memory; pub mod utils; +pub mod cache; \ No newline at end of file From dc9a18e7e0530295c4ae20e6dbeef1a33ad03238 Mon Sep 17 00:00:00 2001 From: dynamder Date: Mon, 15 Dec 2025 00:45:02 +0800 Subject: [PATCH 09/22] fix import error --- src/memory/memory_note.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/memory/memory_note.rs b/src/memory/memory_note.rs index f1bc202..87deaa6 100644 --- a/src/memory/memory_note.rs +++ b/src/memory/memory_note.rs @@ -10,7 +10,8 @@ mod proc_mem; use crate::memory::embedding::Embeddable; use crate::memory::embedding::MemoryEmbedding; -use super::embedding::EmbeddingError; +use super::embedding::EmbeddingCalcError; +use super::embedding::EmbeddingGenError; use super::embedding::EmbeddingModel; use super::memory_links::MemoryLink; @@ -66,10 +67,13 @@ impl MemoryNote { } } impl Embeddable for MemoryNote { - fn embed(&self, embedding_model: &EmbeddingModel) -> Result { + fn embed( + &self, + embedding_model: &EmbeddingModel, + ) -> Result { todo!("Add the embedding logic") } - fn embed_vec(&self, model: &EmbeddingModel) -> Result { + fn embed_vec(&self, model: &EmbeddingModel) -> Result { todo!("Add the embedding logic") } } From 4c907a0ba39f68345927e6e88e64cd7826ac0614 Mon Sep 17 00:00:00 2001 From: dynamder Date: Mon, 15 Dec 2025 00:45:31 +0800 Subject: [PATCH 10/22] detail the scaffolding, add cached_path strategy. --- src/memory/algo/retrieve.rs | 9 ++++++--- src/memory/algo/retrieve/association.rs | 13 +++++++++---- src/memory/algo/retrieve/cached_path.rs | 24 ++++++++++++++++++++++++ src/memory/algo/retrieve/deep_thought.rs | 13 +++++++++---- src/memory/algo/retrieve/short_only.rs | 15 ++++++++++----- src/memory/algo/retrieve/similarity.rs | 22 ++++++++++++++++------ 6 files changed, 74 insertions(+), 22 deletions(-) create mode 100644 src/memory/algo/retrieve/cached_path.rs diff --git a/src/memory/algo/retrieve.rs b/src/memory/algo/retrieve.rs index 8bd1527..1f84c30 100644 --- a/src/memory/algo/retrieve.rs +++ b/src/memory/algo/retrieve.rs @@ -1,10 +1,13 @@ -use crate::memory::embedding::MemoryEmbedding; +use crate::memory::{embedding::MemoryEmbedding, memory_note::MemoryId}; pub mod association; +pub mod cached_path; pub mod deep_thought; pub mod short_only; pub mod similarity; pub trait RetrStrategy { - type RetrRequest; //接受的查询参数类型 - fn retrieve(&self, request: Self::RetrRequest) -> Vec; //TODO:返回类型还没想好,暂定Vec,或许也可以考虑返回迭代器,看具体场景 + type Request: RetrRequest; //接受的查询参数类型 + fn retrieve(&self, request: Self::Request) -> Vec; //TODO:返回类型还没想好,暂定Vec,或许也可以考虑返回迭代器,看具体场景 } + +pub trait RetrRequest {} diff --git a/src/memory/algo/retrieve/association.rs b/src/memory/algo/retrieve/association.rs index b519d83..78b773f 100644 --- a/src/memory/algo/retrieve/association.rs +++ b/src/memory/algo/retrieve/association.rs @@ -1,19 +1,24 @@ use std::sync::Arc; -use crate::memory::working_memory::WorkingMemory; +use crate::memory::{ + algo::retrieve::RetrRequest, memory_note::MemoryId, working_memory::WorkingMemory, +}; use super::RetrStrategy; //用PPR变种算法进行联想 pub struct RetrAssociation { - max_results: usize, + pub max_results: usize, } pub struct AssociationRequest { working_mem: Arc, } + +impl RetrRequest for AssociationRequest {} + impl RetrStrategy for RetrAssociation { - type RetrRequest = AssociationRequest; - fn retrieve(&self, request: Self::RetrRequest) -> Vec { + type Request = AssociationRequest; + fn retrieve(&self, request: Self::Request) -> Vec { todo!() } } diff --git a/src/memory/algo/retrieve/cached_path.rs 
b/src/memory/algo/retrieve/cached_path.rs new file mode 100644 index 0000000..98b60fa --- /dev/null +++ b/src/memory/algo/retrieve/cached_path.rs @@ -0,0 +1,24 @@ +//采用dfs,通过边权中的记忆向量来快速扩展子图信息,详见ReMindRAG + +use crate::memory::working_memory::WorkingMemory; +use std::sync::Arc; + +use super::RetrRequest; +use super::RetrStrategy; + +pub struct RetrCachedPath { + pub max_depth: usize, // dfs的最大深度 + pub expand_threshold: f64, //计算向量与查询向量的相似度大于此值,将被扩展 +} + +pub struct CachedPathRequest { + working_mem: Arc, //计算向量与查询向量的相似度大于此值,将被扩展 +} +impl RetrRequest for CachedPathRequest {} + +impl RetrStrategy for RetrCachedPath { + type Request = CachedPathRequest; + fn retrieve(&self, request: Self::Request) -> Vec { + todo!() + } +} diff --git a/src/memory/algo/retrieve/deep_thought.rs b/src/memory/algo/retrieve/deep_thought.rs index daec33c..626d6b2 100644 --- a/src/memory/algo/retrieve/deep_thought.rs +++ b/src/memory/algo/retrieve/deep_thought.rs @@ -1,18 +1,23 @@ use std::sync::Arc; -use crate::memory::working_memory::WorkingMemory; +use crate::memory::{ + algo::retrieve::RetrRequest, memory_note::MemoryId, working_memory::WorkingMemory, +}; use super::RetrStrategy; // 采用 LLM进行的Plan-on-Graph pub struct RetrDeepThought { - max_depth: usize, + pub max_depth: usize, } pub struct DeepThoughtRequest { working_mem: Arc, } + +impl RetrRequest for DeepThoughtRequest {} + impl RetrStrategy for RetrDeepThought { - type RetrRequest = DeepThoughtRequest; - fn retrieve(&self, request: Self::RetrRequest) -> Vec { + type Request = DeepThoughtRequest; + fn retrieve(&self, request: Self::Request) -> Vec { todo!() } } diff --git a/src/memory/algo/retrieve/short_only.rs b/src/memory/algo/retrieve/short_only.rs index 5cb4255..f224b3c 100644 --- a/src/memory/algo/retrieve/short_only.rs +++ b/src/memory/algo/retrieve/short_only.rs @@ -1,18 +1,23 @@ -use crate::memory::working_memory::WorkingMemory; +use crate::memory::{ + algo::retrieve::RetrRequest, memory_note::MemoryId, working_memory::WorkingMemory, +}; use std::sync::Arc; //仅提取短期记忆策略,即仅提取滑动窗口 use super::RetrStrategy; pub struct RetrShortOnly { - clipping_length: Option, - include_summary: bool, + pub clipping_length: Option, + pub include_summary: bool, } pub struct ShortOnlyRequest { working_mem: Arc, //因为检索算法很可能需要并发执行,使用Arc而非引用确保可以Send } + +impl RetrRequest for ShortOnlyRequest {} + impl RetrStrategy for RetrShortOnly { - type RetrRequest = ShortOnlyRequest; - fn retrieve(&self, request: Self::RetrRequest) -> Vec { + type Request = ShortOnlyRequest; + fn retrieve(&self, request: Self::Request) -> Vec { todo!() } } diff --git a/src/memory/algo/retrieve/similarity.rs b/src/memory/algo/retrieve/similarity.rs index 7697a19..ae9f1a8 100644 --- a/src/memory/algo/retrieve/similarity.rs +++ b/src/memory/algo/retrieve/similarity.rs @@ -1,17 +1,27 @@ //仅提取相似记忆策略,即仅提取相似度大于阈值的记忆片段 use super::RetrStrategy; -use crate::memory::working_memory::WorkingMemory; +use crate::memory::{ + algo::retrieve::RetrRequest, memory_note::MemoryId, working_memory::WorkingMemory, +}; use std::sync::Arc; pub struct RetrSimilarity { - similarity_threshold: f64, - max_results: usize, + pub similarity_threshold: f64, + pub max_results: usize, } pub struct SimilarityRequest { working_mem: Arc, } +impl RetrRequest for SimilarityRequest {} impl RetrStrategy for RetrSimilarity { - type RetrRequest = SimilarityRequest; - fn retrieve(&self, request: Self::RetrRequest) -> Vec { - todo!() + type Request = SimilarityRequest; + fn retrieve(&self, request: Self::Request) -> Vec { + //TODO: 减少计算量,以下只是一个最初步的实现 + 
let cos_similarities: Vec<(f64, MemoryId)> = vec![]; + cos_similarities + .into_iter() + .filter(|(similarity, _)| *similarity > self.similarity_threshold) + .map(|(_, id)| id) + .take(self.max_results) + .collect() } } From a6df5faea189e4e8410bba8fe289966240dbbd99 Mon Sep 17 00:00:00 2001 From: dynamder Date: Tue, 16 Dec 2025 01:34:35 +0800 Subject: [PATCH 11/22] scaffolding naive ppr algorithm. --- src/utils.rs | 3 +- src/utils/graph_algo.rs | 78 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 src/utils/graph_algo.rs diff --git a/src/utils.rs b/src/utils.rs index 717e162..853c8e3 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1 +1,2 @@ -pub mod pipe; \ No newline at end of file +pub mod graph_algo; +pub mod pipe; diff --git a/src/utils/graph_algo.rs b/src/utils/graph_algo.rs new file mode 100644 index 0000000..c2c04d7 --- /dev/null +++ b/src/utils/graph_algo.rs @@ -0,0 +1,78 @@ +use std::{collections::HashMap, hash::Hash}; + +use petgraph::{ + algo::UnitMeasure, + visit::{EdgeRef, IntoEdges, IntoNodeIdentifiers, NodeCount, NodeIndexable}, +}; + +#[track_caller] +pub fn naive_ppr( + graph: G, + damping_factor: D, + source_bias: HashMap, + nb_iter: usize, +) -> HashMap +where + G: NodeCount + IntoEdges + NodeIndexable + IntoNodeIdentifiers, + D: UnitMeasure + Copy, + G::NodeId: Hash + Eq, +{ + let node_count = graph.node_count(); + if node_count == 0 { + return HashMap::new(); + } + + //检查阻尼系数 + assert!( + D::zero() <= damping_factor && damping_factor <= D::one(), + "Damping factor should be between 0 et 1." + ); + let d_node_count = D::from_usize(node_count); + + //检查个性化分布是不是一个概率分布 + let bias_sum: D = source_bias.values().copied().sum(); + assert!( + bias_sum > D::zero(), + "Personalized Source bias sum must be positive" + ); + //归一化个性化向量(初始向量) + let normalized_bias: HashMap = if bias_sum != D::one() { + source_bias + .into_iter() + .map(|(node_id, bias)| (node_id, bias / bias_sum)) + .collect() + } else { + source_bias + }; + + //图中有效的索引值,适配StableGraph(索引可能不连续) + let valid_index = graph + .node_identifiers() + .map(|node_id| graph.to_index(node_id)) + .collect::>(); + + //ppr值的存储 + //此处可能有大量内存浪费(无效的索引值占位),考虑到工作记忆子图不会过于频繁释放和加载,这个内存开销应该是可以接受的 + let mut ppr_ranks = vec![D::zero(); graph.node_bound()]; + let mut out_degrees = vec![D::zero(); graph.node_bound()]; + + //初始化PPR向量 + valid_index.iter().for_each(|&valid_index| { + ppr_ranks[valid_index] = D::one() / d_node_count //SAFEUNWRAP: 已经预先分配了索引上限大小的内存,不会越界访问。 + }); + + //预计算每个节点的出度 + graph.node_identifiers().for_each(|node_id| { + out_degrees[graph.to_index(node_id)] = graph.edges(node_id).map(|_| D::one()).sum(); + }); + + for _ in 0..nb_iter { + todo!("Implement the PPR algorithm") + } + + //返回PPR向量,HashMap形式 + graph + .node_identifiers() + .map(|node_id| (node_id, ppr_ranks[graph.to_index(node_id)])) + .collect() +} From 6f9841962cda386de7435e7349ee6ae2c83299b4 Mon Sep 17 00:00:00 2001 From: dynamder Date: Tue, 16 Dec 2025 01:48:10 +0800 Subject: [PATCH 12/22] modify the way to init PPR vector, from 1/n to personalized vector itself. 
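Instead of starting every node at the uniform value 1/n, the PPR iterate is now
seeded with the normalized personalized vector itself. Because the source bias
comes from a vector-similarity top-k (k is usually small), that distribution is
typically much closer to the fixed point, so the iteration usually converges
faster. Simplified f64 sketch of the change (the actual code is generic over the
weight type D):

    // before: uniform initialization over all valid nodes
    ppr_ranks[graph.to_index(node_id)] = 1.0 / node_count as f64;

    // after: seed only the biased source nodes with their normalized weight;
    // every other entry stays at zero
    ppr_ranks[graph.to_index(node_id)] = normalized_bias[&node_id];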
--- src/utils/graph_algo.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/utils/graph_algo.rs b/src/utils/graph_algo.rs index c2c04d7..e36321f 100644 --- a/src/utils/graph_algo.rs +++ b/src/utils/graph_algo.rs @@ -56,9 +56,9 @@ where let mut ppr_ranks = vec![D::zero(); graph.node_bound()]; let mut out_degrees = vec![D::zero(); graph.node_bound()]; - //初始化PPR向量 - valid_index.iter().for_each(|&valid_index| { - ppr_ranks[valid_index] = D::one() / d_node_count //SAFEUNWRAP: 已经预先分配了索引上限大小的内存,不会越界访问。 + //使用个性化向量,初始化PPR值向量,由于源节点有向量相似性取top-k提供(k通常不大),这样初始化通常可以加快收敛速度 + normalized_bias.iter().for_each(|(&node_id, &bias)| { + ppr_ranks[graph.to_index(node_id)] = bias; //SAFEUNWRAP: 已经预先分配了索引上限大小的内存,不会越界访问。 }); //预计算每个节点的出度 From 37648ca3efef562310fdcf448e2b4b6a70f3e422 Mon Sep 17 00:00:00 2001 From: dynamder Date: Tue, 16 Dec 2025 12:19:32 +0800 Subject: [PATCH 13/22] sketch the naive_ppr algo, not tested. --- src/utils/graph_algo.rs | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/src/utils/graph_algo.rs b/src/utils/graph_algo.rs index e36321f..d694ab4 100644 --- a/src/utils/graph_algo.rs +++ b/src/utils/graph_algo.rs @@ -67,7 +67,40 @@ where }); for _ in 0..nb_iter { - todo!("Implement the PPR algorithm") + let ppr_vec_i = valid_index + .iter() + .map(|&computing_idx| { + let iter_ppr = valid_index + .iter() + .map(|&idx| { + //计算每个节点的出度 + let mut out_edges = graph.edges(graph.from_index(idx)); + + //拆分标准PPR公式为整体求和形式,便于编写和计算,以及对可能的优化更友好 + if out_edges.any(|e| e.target() == graph.from_index(computing_idx)) { + damping_factor * ppr_ranks[idx] / out_degrees[idx] + } else if out_degrees[idx] == D::zero() { + damping_factor + * ppr_ranks[idx] + * normalized_bias[&graph.from_index(computing_idx)] + } else { + (D::one() - damping_factor) + * ppr_ranks[idx] + * normalized_bias[&graph.from_index(computing_idx)] + } + }) + .sum::(); + (computing_idx, iter_ppr) + }) + .collect::>(); + + // 归一化PPR值,确保数值稳定,总和为1 + + let sum = ppr_vec_i.iter().map(|(_, ppr)| *ppr).sum::(); + + ppr_vec_i.iter().for_each(|&(idx, ppr)| { + ppr_ranks[idx] = ppr / sum; + }); } //返回PPR向量,HashMap形式 From e3c014d2f44c590385993fcc94d65b09db53ad71 Mon Sep 17 00:00:00 2001 From: dynamder Date: Wed, 17 Dec 2025 00:38:13 +0800 Subject: [PATCH 14/22] fix some bugs in naive_ppr, pass the first unit test. 
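Guard the personalization lookups with get()/unwrap_or so nodes that are absent
from the source bias no longer hit a panicking HashMap index, renormalize the
final ranks before returning them, and add the first toy-graph unit test.
Minimal usage sketch, mirroring that test (f64 weights, 0.15 damping, 15
iterations):

    use std::collections::HashMap;
    use petgraph::prelude::StableDiGraph;

    let mut graph = StableDiGraph::<String, f64>::new();
    let a = graph.add_node("A".to_string());
    let b = graph.add_node("B".to_string());
    graph.add_edge(a, b, 1.0);

    let mut source_bias = HashMap::new();
    source_bias.insert(a, 1.0);

    let ppr = naive_ppr(&graph, 0.15_f64, source_bias, 15);
    // the returned ranks are normalized, so they sum to 1
    assert!((ppr.values().copied().sum::<f64>() - 1.0).abs() < 1e-9);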
--- src/utils/graph_algo.rs | 120 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 113 insertions(+), 7 deletions(-) diff --git a/src/utils/graph_algo.rs b/src/utils/graph_algo.rs index d694ab4..85297b1 100644 --- a/src/utils/graph_algo.rs +++ b/src/utils/graph_algo.rs @@ -5,6 +5,10 @@ use petgraph::{ visit::{EdgeRef, IntoEdges, IntoNodeIdentifiers, NodeCount, NodeIndexable}, }; +///PPR: ppr_s = dampling_factor * P * ppr_s + (1-damping_factor) * personalized_vec, P为转移矩阵 +/// 对无出度的节点,采取与source_bias中的节点建立连接 +/// 必须保证source_bias的key是有效的NodeId, 否则会得到不正确的结果 +// 由于NodeId会由MemoryCluster提供,这不会造成额外的检查负担 #[track_caller] pub fn naive_ppr( graph: G, @@ -80,13 +84,19 @@ where if out_edges.any(|e| e.target() == graph.from_index(computing_idx)) { damping_factor * ppr_ranks[idx] / out_degrees[idx] } else if out_degrees[idx] == D::zero() { - damping_factor - * ppr_ranks[idx] - * normalized_bias[&graph.from_index(computing_idx)] + normalized_bias + .get(&graph.from_index(computing_idx)) + .map(|personal_bias| { + damping_factor * ppr_ranks[idx] * *personal_bias + }) + .unwrap_or(D::zero()) } else { - (D::one() - damping_factor) - * ppr_ranks[idx] - * normalized_bias[&graph.from_index(computing_idx)] + normalized_bias + .get(&graph.from_index(computing_idx)) + .map(|personal_bias| { + (D::one() - damping_factor) * ppr_ranks[idx] * *personal_bias + }) + .unwrap_or(D::zero()) } }) .sum::(); @@ -103,9 +113,105 @@ where }); } + //最终归一化 + let sum = ppr_ranks.iter().map(|ppr| *ppr).sum::(); + //返回PPR向量,HashMap形式 graph .node_identifiers() - .map(|node_id| (node_id, ppr_ranks[graph.to_index(node_id)])) + .map(|node_id| (node_id, ppr_ranks[graph.to_index(node_id)] / sum)) .collect() } + +#[cfg(test)] +mod test { + + use petgraph::{matrix_graph::NodeIndex, prelude::StableDiGraph}; + + use super::*; + fn relative_error(actual: f64, expected: f64) -> f64 { + if expected.abs() < f64::EPSILON && actual.abs() < f64::EPSILON { + 0.0 + } else { + let diff = (actual - expected).abs(); + let denominator = expected.abs().max(actual.abs()).max(f64::EPSILON); + diff / denominator + } + } + + fn test_toy_graph() -> (StableDiGraph, Vec>) { + let mut graph = StableDiGraph::new(); + let a = graph.add_node("A".to_string()); + let b = graph.add_node("B".to_string()); + //制造索引空洞 + graph.remove_node(b); + let b = graph.add_node("B".to_string()); + let c = graph.add_node("C".to_string()); + let d = graph.add_node("D".to_string()); + + graph.add_edge(a, b, 1.0); + graph.add_edge(a, c, 1.0); + graph.add_edge(b, c, 1.0); + graph.add_edge(c, d, 1.0); + + (graph, vec![a, b, c, d]) + } + fn toy_graph_with_init_a() -> ( + StableDiGraph, + HashMap, f64>, + Vec>, + ) { + let (graph, indexes) = test_toy_graph(); + let ans_vec: Vec = vec![0.851652742, 0.06387396045, 0.07345504972, 0.01101824785]; + let ans = indexes.iter().copied().zip(ans_vec).collect(); + (graph, ans, indexes) + } + fn toy_graph_with_init_b() -> ( + StableDiGraph, + HashMap, f64>, + Vec>, + ) { + let (graph, indexes) = test_toy_graph(); + let ans_vec: Vec = vec![]; + let ans = indexes.iter().copied().zip(ans_vec).collect(); + (graph, ans, indexes) + } + fn toy_graph_with_init_ab() -> ( + StableDiGraph, + HashMap, f64>, + Vec>, + ) { + let (graph, indexes) = test_toy_graph(); + let ans_vec: Vec = vec![]; + let ans = indexes.iter().copied().zip(ans_vec).collect(); + (graph, ans, indexes) + } + #[test] + fn ppr_toy_graph_init_a() { + let (graph, true_ans, indexes) = toy_graph_with_init_a(); + let mut source_bias = HashMap::new(); + source_bias.insert(indexes[0], 1.0); + + let 
ppr_ans = naive_ppr(&graph, 0.15_f64, source_bias, 15); + let ans_sum = ppr_ans.values().copied().sum::(); + assert_eq!(ans_sum, 1.0); + + let relative_error = 0.25 + * indexes + .iter() + .map(|idx| { + let actual = ppr_ans[idx]; + let expected = true_ans[idx]; + (actual - expected).abs() + }) + .sum::(); + + assert!( + relative_error < 0.005, + "failed with relative err {}, whole ppr_vec is : {:?}, but it should be : {:?}", + relative_error, + ppr_ans, + true_ans + ) + } +} From f38070c427b82c682f4432a5ada5777f8b0a27f0 Mon Sep 17 00:00:00 2001 From: dynamder Date: Thu, 18 Dec 2025 00:48:53 +0800 Subject: [PATCH 15/22] fix mathematical error in naive_ppr, pass 3 unit tests. --- src/utils/graph_algo.rs | 141 +++++++++++++++++++++++++++++----------- 1 file changed, 104 insertions(+), 37 deletions(-) diff --git a/src/utils/graph_algo.rs b/src/utils/graph_algo.rs index 85297b1..2ce02c6 100644 --- a/src/utils/graph_algo.rs +++ b/src/utils/graph_algo.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, hash::Hash}; +use std::{collections::HashMap, fmt::Debug, hash::Hash}; use petgraph::{ algo::UnitMeasure, @@ -13,7 +13,7 @@ use petgraph::{ pub fn naive_ppr( graph: G, damping_factor: D, - source_bias: HashMap, + personalized_vec: HashMap, nb_iter: usize, ) -> HashMap where @@ -31,22 +31,22 @@ where D::zero() <= damping_factor && damping_factor <= D::one(), "Damping factor should be between 0 et 1." ); - let d_node_count = D::from_usize(node_count); //检查个性化分布是不是一个概率分布 - let bias_sum: D = source_bias.values().copied().sum(); + let personalized_sum: D = personalized_vec.values().copied().sum(); assert!( - bias_sum > D::zero(), + personalized_sum > D::zero(), "Personalized Source bias sum must be positive" ); + //归一化个性化向量(初始向量) - let normalized_bias: HashMap = if bias_sum != D::one() { - source_bias + let normalized_personalized_vec: HashMap = if personalized_sum != D::one() { + personalized_vec .into_iter() - .map(|(node_id, bias)| (node_id, bias / bias_sum)) + .map(|(node_id, bias)| (node_id, bias / personalized_sum)) .collect() } else { - source_bias + personalized_vec }; //图中有效的索引值,适配StableGraph(索引可能不连续) @@ -61,56 +61,67 @@ where let mut out_degrees = vec![D::zero(); graph.node_bound()]; //使用个性化向量,初始化PPR值向量,由于源节点有向量相似性取top-k提供(k通常不大),这样初始化通常可以加快收敛速度 - normalized_bias.iter().for_each(|(&node_id, &bias)| { - ppr_ranks[graph.to_index(node_id)] = bias; //SAFEUNWRAP: 已经预先分配了索引上限大小的内存,不会越界访问。 - }); + normalized_personalized_vec + .iter() + .for_each(|(&node_id, &bias)| { + ppr_ranks[graph.to_index(node_id)] = bias; //SAFEUNWRAP: 已经预先分配了索引上限大小的内存,不会越界访问。 + }); + let normalized_bias_len = normalized_personalized_vec.len(); + //println!("normalized_bias: {:?}", normalized_bias); //预计算每个节点的出度 graph.node_identifiers().for_each(|node_id| { out_degrees[graph.to_index(node_id)] = graph.edges(node_id).map(|_| D::one()).sum(); }); + //println!("out_degrees: {:?}", out_degrees); - for _ in 0..nb_iter { + for i in 0..nb_iter { let ppr_vec_i = valid_index .iter() .map(|&computing_idx| { let iter_ppr = valid_index .iter() .map(|&idx| { - //计算每个节点的出度 + //找到每个节点的出边 let mut out_edges = graph.edges(graph.from_index(idx)); - //拆分标准PPR公式为整体求和形式,便于编写和计算,以及对可能的优化更友好 + //游走部分的计算,对于无出度节点,默认其连接至所有个性化向量中不为0的节点 if out_edges.any(|e| e.target() == graph.from_index(computing_idx)) { damping_factor * ppr_ranks[idx] / out_degrees[idx] } else if out_degrees[idx] == D::zero() { - normalized_bias + normalized_personalized_vec .get(&graph.from_index(computing_idx)) - .map(|personal_bias| { - damping_factor * ppr_ranks[idx] * 
*personal_bias + .map(|_| { + damping_factor * ppr_ranks[idx] + / D::from_usize(normalized_bias_len) }) .unwrap_or(D::zero()) } else { - normalized_bias - .get(&graph.from_index(computing_idx)) - .map(|personal_bias| { - (D::one() - damping_factor) * ppr_ranks[idx] * *personal_bias - }) - .unwrap_or(D::zero()) + D::zero() } }) .sum::(); - (computing_idx, iter_ppr) + + //随机重启部分计算 + let random_back_part = if let Some(per_i) = + normalized_personalized_vec.get(&graph.from_index(computing_idx)) + { + (D::one() - damping_factor) * *per_i + } else { + D::zero() + }; + + (computing_idx, iter_ppr + random_back_part) }) .collect::>(); // 归一化PPR值,确保数值稳定,总和为1 - let sum = ppr_vec_i.iter().map(|(_, ppr)| *ppr).sum::(); ppr_vec_i.iter().for_each(|&(idx, ppr)| { ppr_ranks[idx] = ppr / sum; }); + //println!("iteration {i}: PPR values: {:?}", ppr_ranks); } //最终归一化 @@ -129,13 +140,12 @@ mod test { use petgraph::{matrix_graph::NodeIndex, prelude::StableDiGraph}; use super::*; - fn relative_error(actual: f64, expected: f64) -> f64 { + fn diff(actual: f64, expected: f64) -> f64 { if expected.abs() < f64::EPSILON && actual.abs() < f64::EPSILON { 0.0 } else { let diff = (actual - expected).abs(); - let denominator = expected.abs().max(actual.abs()).max(f64::EPSILON); - diff / denominator + diff } } @@ -172,7 +182,7 @@ mod test { Vec>, ) { let (graph, indexes) = test_toy_graph(); - let ans_vec: Vec = vec![]; + let ans_vec: Vec = vec![0.0, 0.852878432, 0.1279320211, 0.01918954688]; let ans = indexes.iter().copied().zip(ans_vec).collect(); (graph, ans, indexes) } @@ -182,7 +192,7 @@ mod test { Vec>, ) { let (graph, indexes) = test_toy_graph(); - let ans_vec: Vec = vec![]; + let ans_vec: Vec = vec![0.4261326137, 0.4580925718, 0.1006738318, 0.00510098267]; let ans = indexes.iter().copied().zip(ans_vec).collect(); (graph, ans, indexes) } @@ -194,22 +204,79 @@ mod test { let ppr_ans = naive_ppr(&graph, 0.15_f64, source_bias, 15); let ans_sum = ppr_ans.values().copied().sum::(); - assert_eq!(ans_sum, 1.0); + assert!(ans_sum - 1.0 < f64::EPSILON); + + let avg_diff = 0.25 + * indexes + .iter() + .map(|idx| { + let actual = ppr_ans[idx]; + let expected = true_ans[idx]; + diff(actual, expected) + }) + .sum::(); + + assert!( + avg_diff < 0.005, + "failed with avg_diff {}, whole ppr_vec is : {:?}, but it should be : {:?}", + avg_diff, + ppr_ans, + true_ans + ) + } + #[test] + fn ppr_toy_graph_init_b() { + let (graph, true_ans, indexes) = toy_graph_with_init_b(); + let mut source_bias = HashMap::new(); + source_bias.insert(indexes[1], 1.0); + + let ppr_ans = naive_ppr(&graph, 0.15_f64, source_bias, 15); + let ans_sum = ppr_ans.values().copied().sum::(); + assert!(ans_sum - 1.0 < f64::EPSILON); + + let avg_diff = 0.25 + * indexes + .iter() + .map(|idx| { + let actual = ppr_ans[idx]; + let expected = true_ans[idx]; + diff(actual, expected) + }) + .sum::(); + + assert!( + avg_diff < 0.005, + "failed with avg_diff {}, whole ppr_vec is : {:?}, but it should be : {:?}", + avg_diff, + ppr_ans, + true_ans + ) + } + #[test] + fn ppr_toy_graph_init_ab() { + let (graph, true_ans, indexes) = toy_graph_with_init_ab(); + let mut source_bias = HashMap::new(); + source_bias.insert(indexes[0], 1.0); + source_bias.insert(indexes[1], 1.0); + + let ppr_ans = naive_ppr(&graph, 0.15_f64, source_bias, 15); + let ans_sum = ppr_ans.values().copied().sum::(); + assert!(ans_sum - 1.0 < f64::EPSILON); - let relative_error = 0.25 + let avg_diff = 0.25 * indexes .iter() .map(|idx| { let actual = ppr_ans[idx]; let expected = true_ans[idx]; - (actual - 
expected).abs() + diff(actual, expected) }) .sum::(); assert!( - relative_error < 0.005, - "failed with relative err {}, whole ppr_vec is : {:?}, but it should be : {:?}", - relative_error, + avg_diff < 0.005, + "failed with avg_diff {}, whole ppr_vec is : {:?}, but it should be : {:?}", + avg_diff, ppr_ans, true_ans ) From f80fa550515dd370ec018b937dd3daa5b5815ff7 Mon Sep 17 00:00:00 2001 From: dynamder Date: Thu, 18 Dec 2025 01:10:00 +0800 Subject: [PATCH 16/22] add pressure test. --- src/utils/graph_algo.rs | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/utils/graph_algo.rs b/src/utils/graph_algo.rs index 2ce02c6..9939f69 100644 --- a/src/utils/graph_algo.rs +++ b/src/utils/graph_algo.rs @@ -137,6 +137,7 @@ where #[cfg(test)] mod test { + use mockall::predicate::float; use petgraph::{matrix_graph::NodeIndex, prelude::StableDiGraph}; use super::*; @@ -148,6 +149,23 @@ mod test { diff } } + fn pressure_large_graph() -> (StableDiGraph, Vec>) { + let mut graph = StableDiGraph::new(); + let mut nodes = Vec::new(); + for i in 0..5000 { + let mut node = graph.add_node("".to_string()); + if i % 2 == 0 || i % 7 == 0 { + graph.remove_node(node); + node = graph.add_node("".to_string()); + } + nodes.push(node); + graph.add_edge(node, node, 1.0); + nodes.iter().for_each(|idx| { + graph.add_edge(node, *idx, 1.0); + }); + } + (graph, nodes) + } fn test_toy_graph() -> (StableDiGraph, Vec>) { let mut graph = StableDiGraph::new(); @@ -281,4 +299,16 @@ mod test { true_ans ) } + #[test] + fn pressure_large_graph_test() { + let (graph, nodes) = pressure_large_graph(); + let mut source_bias = HashMap::new(); + nodes.iter().take(10).for_each(|idx| { + source_bias.insert(*idx, graph.to_index(*idx) as f64); + }); + + let ppr_ans = naive_ppr(&graph, 0.15_f64, source_bias, 15); + let ans_sum = ppr_ans.values().copied().sum::(); + assert!(ans_sum - 1.0 < f64::EPSILON); + } } From 4e8a299abcabfc5f14c3c4de7f5e1e05b3334622 Mon Sep 17 00:00:00 2001 From: dynamder Date: Thu, 18 Dec 2025 17:33:14 +0800 Subject: [PATCH 17/22] add benches --- Cargo.toml | 7 + benches/ppr.rs | 162 ++++++++++++++++++++ src/utils/graph_algo.rs | 316 +--------------------------------------- 3 files changed, 171 insertions(+), 314 deletions(-) create mode 100644 benches/ppr.rs diff --git a/Cargo.toml b/Cargo.toml index c54f17b..290e7e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,3 +39,10 @@ strip = true opt-level = 3 panic = 'abort' codegen-units = 1 + +[dev-dependencies] +criterion = "0.8.1" + +[[bench]] +name = "ppr" +harness = false diff --git a/benches/ppr.rs b/benches/ppr.rs new file mode 100644 index 0000000..7ed243e --- /dev/null +++ b/benches/ppr.rs @@ -0,0 +1,162 @@ +use criterion::{ + BenchmarkGroup, Criterion, SamplingMode, black_box, criterion_group, criterion_main, +}; +use petgraph::{matrix_graph::NodeIndex, prelude::StableDiGraph, visit::NodeIndexable}; +use soul_mem::utils::graph_algo::ppr::naive_ppr; +use std::collections::HashMap; + +/// 创建压力测试图的辅助函数(在benchmark计时范围外使用) +fn pressure_large_graph() -> (StableDiGraph, Vec>) { + let mut graph = StableDiGraph::new(); + let mut nodes = Vec::new(); + for i in 0..500 { + let mut node = graph.add_node("".to_string()); + if i % 2 == 0 || i % 7 == 0 { + graph.remove_node(node); + node = graph.add_node("".to_string()); + } + nodes.push(node); + graph.add_edge(node, node, 1.0); + nodes.iter().for_each(|idx| { + graph.add_edge(node, *idx, 1.0); + }); + } + (graph, nodes) +} + +/// 准备测试数据(在benchmark外执行) +fn prepare_test_data() -> (StableDiGraph, 
HashMap, f64>) { + let (graph, nodes) = pressure_large_graph(); + let mut source_bias = HashMap::new(); + nodes.iter().take(10).for_each(|idx| { + source_bias.insert(*idx, graph.to_index(*idx) as f64); + }); + (graph, source_bias) +} + +/// 专门测试naive_ppr函数性能的benchmark +fn bench_naive_ppr_function(c: &mut Criterion) { + // 准备测试数据(不计时) + let (graph, source_bias) = prepare_test_data(); + + // 设置采样次数为10 + let mut group = c.benchmark_group("naive_ppr_basic"); + group.sample_size(10); // 设置采样次数为10 + group.sampling_mode(SamplingMode::Flat); + + // Benchmark 1: 基础性能测试 + group.bench_function("basic", |b| { + b.iter(|| { + let result = naive_ppr( + black_box(&graph), + black_box(0.15_f64), + black_box(source_bias.clone()), + black_box(15), + ); + black_box(result); + }); + }); + + // Benchmark 2: 测试不同阻尼因子 + group.bench_function("damping_high", |b| { + b.iter(|| { + let result = naive_ppr( + black_box(&graph), + black_box(0.85_f64), // 高阻尼因子 + black_box(source_bias.clone()), + black_box(15), + ); + black_box(result); + }); + }); + + // Benchmark 3: 测试不同迭代次数 + group.bench_function("iterations_50", |b| { + b.iter(|| { + let result = naive_ppr( + black_box(&graph), + black_box(0.15_f64), + black_box(source_bias.clone()), + black_box(50), // 更多迭代次数 + ); + black_box(result); + }); + }); + + // Benchmark 4: 测试更少的迭代次数 + group.bench_function("iterations_5", |b| { + b.iter(|| { + let result = naive_ppr( + black_box(&graph), + black_box(0.15_f64), + black_box(source_bias.clone()), + black_box(5), // 更少迭代次数 + ); + black_box(result); + }); + }); + + // Benchmark 5: 测试中等阻尼因子 + group.bench_function("damping_medium", |b| { + b.iter(|| { + let result = naive_ppr( + black_box(&graph), + black_box(0.5_f64), // 中等阻尼因子 + black_box(source_bias.clone()), + black_box(15), + ); + black_box(result); + }); + }); + + group.finish(); +} + +/// 参数化benchmark组,测试不同参数组合 +fn bench_naive_ppr_parameterized(c: &mut Criterion) { + let (graph, source_bias) = prepare_test_data(); + + // 设置采样次数为10 + let mut group = c.benchmark_group("naive_ppr_parametrization"); + group.sample_size(10); // 设置采样次数为10 + group.sampling_mode(SamplingMode::Flat); + + // 测试不同迭代次数 + for iterations in [5, 10, 15, 20, 50].iter() { + group.bench_function(format!("iterations_{}", iterations), |b| { + b.iter(|| { + let result = naive_ppr( + black_box(&graph), + black_box(0.15_f64), + black_box(source_bias.clone()), + black_box(*iterations), + ); + black_box(result); + }); + }); + } + + // 测试不同阻尼因子 + for damping in [0.1, 0.3, 0.5, 0.7, 0.9].iter() { + group.bench_function(format!("damping_{}", damping), |b| { + b.iter(|| { + let result = naive_ppr( + black_box(&graph), + black_box(*damping as f64), + black_box(source_bias.clone()), + black_box(15), + ); + black_box(result); + }); + }); + } + + group.finish(); +} + +criterion_group!( + benches, + bench_naive_ppr_function, + bench_naive_ppr_parameterized +); +criterion_main!(benches); diff --git a/src/utils/graph_algo.rs b/src/utils/graph_algo.rs index 9939f69..b696c0f 100644 --- a/src/utils/graph_algo.rs +++ b/src/utils/graph_algo.rs @@ -1,314 +1,2 @@ -use std::{collections::HashMap, fmt::Debug, hash::Hash}; - -use petgraph::{ - algo::UnitMeasure, - visit::{EdgeRef, IntoEdges, IntoNodeIdentifiers, NodeCount, NodeIndexable}, -}; - -///PPR: ppr_s = dampling_factor * P * ppr_s + (1-damping_factor) * personalized_vec, P为转移矩阵 -/// 对无出度的节点,采取与source_bias中的节点建立连接 -/// 必须保证source_bias的key是有效的NodeId, 否则会得到不正确的结果 -// 由于NodeId会由MemoryCluster提供,这不会造成额外的检查负担 -#[track_caller] -pub fn naive_ppr( - graph: G, - damping_factor: D, - 
personalized_vec: HashMap, - nb_iter: usize, -) -> HashMap -where - G: NodeCount + IntoEdges + NodeIndexable + IntoNodeIdentifiers, - D: UnitMeasure + Copy, - G::NodeId: Hash + Eq, -{ - let node_count = graph.node_count(); - if node_count == 0 { - return HashMap::new(); - } - - //检查阻尼系数 - assert!( - D::zero() <= damping_factor && damping_factor <= D::one(), - "Damping factor should be between 0 et 1." - ); - - //检查个性化分布是不是一个概率分布 - let personalized_sum: D = personalized_vec.values().copied().sum(); - assert!( - personalized_sum > D::zero(), - "Personalized Source bias sum must be positive" - ); - - //归一化个性化向量(初始向量) - let normalized_personalized_vec: HashMap = if personalized_sum != D::one() { - personalized_vec - .into_iter() - .map(|(node_id, bias)| (node_id, bias / personalized_sum)) - .collect() - } else { - personalized_vec - }; - - //图中有效的索引值,适配StableGraph(索引可能不连续) - let valid_index = graph - .node_identifiers() - .map(|node_id| graph.to_index(node_id)) - .collect::>(); - - //ppr值的存储 - //此处可能有大量内存浪费(无效的索引值占位),考虑到工作记忆子图不会过于频繁释放和加载,这个内存开销应该是可以接受的 - let mut ppr_ranks = vec![D::zero(); graph.node_bound()]; - let mut out_degrees = vec![D::zero(); graph.node_bound()]; - - //使用个性化向量,初始化PPR值向量,由于源节点有向量相似性取top-k提供(k通常不大),这样初始化通常可以加快收敛速度 - normalized_personalized_vec - .iter() - .for_each(|(&node_id, &bias)| { - ppr_ranks[graph.to_index(node_id)] = bias; //SAFEUNWRAP: 已经预先分配了索引上限大小的内存,不会越界访问。 - }); - let normalized_bias_len = normalized_personalized_vec.len(); - //println!("normalized_bias: {:?}", normalized_bias); - - //预计算每个节点的出度 - graph.node_identifiers().for_each(|node_id| { - out_degrees[graph.to_index(node_id)] = graph.edges(node_id).map(|_| D::one()).sum(); - }); - //println!("out_degrees: {:?}", out_degrees); - - for i in 0..nb_iter { - let ppr_vec_i = valid_index - .iter() - .map(|&computing_idx| { - let iter_ppr = valid_index - .iter() - .map(|&idx| { - //找到每个节点的出边 - let mut out_edges = graph.edges(graph.from_index(idx)); - - //游走部分的计算,对于无出度节点,默认其连接至所有个性化向量中不为0的节点 - if out_edges.any(|e| e.target() == graph.from_index(computing_idx)) { - damping_factor * ppr_ranks[idx] / out_degrees[idx] - } else if out_degrees[idx] == D::zero() { - normalized_personalized_vec - .get(&graph.from_index(computing_idx)) - .map(|_| { - damping_factor * ppr_ranks[idx] - / D::from_usize(normalized_bias_len) - }) - .unwrap_or(D::zero()) - } else { - D::zero() - } - }) - .sum::(); - - //随机重启部分计算 - let random_back_part = if let Some(per_i) = - normalized_personalized_vec.get(&graph.from_index(computing_idx)) - { - (D::one() - damping_factor) * *per_i - } else { - D::zero() - }; - - (computing_idx, iter_ppr + random_back_part) - }) - .collect::>(); - - // 归一化PPR值,确保数值稳定,总和为1 - let sum = ppr_vec_i.iter().map(|(_, ppr)| *ppr).sum::(); - - ppr_vec_i.iter().for_each(|&(idx, ppr)| { - ppr_ranks[idx] = ppr / sum; - }); - //println!("iteration {i}: PPR values: {:?}", ppr_ranks); - } - - //最终归一化 - let sum = ppr_ranks.iter().map(|ppr| *ppr).sum::(); - - //返回PPR向量,HashMap形式 - graph - .node_identifiers() - .map(|node_id| (node_id, ppr_ranks[graph.to_index(node_id)] / sum)) - .collect() -} - -#[cfg(test)] -mod test { - - use mockall::predicate::float; - use petgraph::{matrix_graph::NodeIndex, prelude::StableDiGraph}; - - use super::*; - fn diff(actual: f64, expected: f64) -> f64 { - if expected.abs() < f64::EPSILON && actual.abs() < f64::EPSILON { - 0.0 - } else { - let diff = (actual - expected).abs(); - diff - } - } - fn pressure_large_graph() -> (StableDiGraph, Vec>) { - let mut graph = StableDiGraph::new(); - let mut 
nodes = Vec::new(); - for i in 0..5000 { - let mut node = graph.add_node("".to_string()); - if i % 2 == 0 || i % 7 == 0 { - graph.remove_node(node); - node = graph.add_node("".to_string()); - } - nodes.push(node); - graph.add_edge(node, node, 1.0); - nodes.iter().for_each(|idx| { - graph.add_edge(node, *idx, 1.0); - }); - } - (graph, nodes) - } - - fn test_toy_graph() -> (StableDiGraph, Vec>) { - let mut graph = StableDiGraph::new(); - let a = graph.add_node("A".to_string()); - let b = graph.add_node("B".to_string()); - //制造索引空洞 - graph.remove_node(b); - let b = graph.add_node("B".to_string()); - let c = graph.add_node("C".to_string()); - let d = graph.add_node("D".to_string()); - - graph.add_edge(a, b, 1.0); - graph.add_edge(a, c, 1.0); - graph.add_edge(b, c, 1.0); - graph.add_edge(c, d, 1.0); - - (graph, vec![a, b, c, d]) - } - fn toy_graph_with_init_a() -> ( - StableDiGraph, - HashMap, f64>, - Vec>, - ) { - let (graph, indexes) = test_toy_graph(); - let ans_vec: Vec = vec![0.851652742, 0.06387396045, 0.07345504972, 0.01101824785]; - let ans = indexes.iter().copied().zip(ans_vec).collect(); - (graph, ans, indexes) - } - fn toy_graph_with_init_b() -> ( - StableDiGraph, - HashMap, f64>, - Vec>, - ) { - let (graph, indexes) = test_toy_graph(); - let ans_vec: Vec = vec![0.0, 0.852878432, 0.1279320211, 0.01918954688]; - let ans = indexes.iter().copied().zip(ans_vec).collect(); - (graph, ans, indexes) - } - fn toy_graph_with_init_ab() -> ( - StableDiGraph, - HashMap, f64>, - Vec>, - ) { - let (graph, indexes) = test_toy_graph(); - let ans_vec: Vec = vec![0.4261326137, 0.4580925718, 0.1006738318, 0.00510098267]; - let ans = indexes.iter().copied().zip(ans_vec).collect(); - (graph, ans, indexes) - } - #[test] - fn ppr_toy_graph_init_a() { - let (graph, true_ans, indexes) = toy_graph_with_init_a(); - let mut source_bias = HashMap::new(); - source_bias.insert(indexes[0], 1.0); - - let ppr_ans = naive_ppr(&graph, 0.15_f64, source_bias, 15); - let ans_sum = ppr_ans.values().copied().sum::(); - assert!(ans_sum - 1.0 < f64::EPSILON); - - let avg_diff = 0.25 - * indexes - .iter() - .map(|idx| { - let actual = ppr_ans[idx]; - let expected = true_ans[idx]; - diff(actual, expected) - }) - .sum::(); - - assert!( - avg_diff < 0.005, - "failed with avg_diff {}, whole ppr_vec is : {:?}, but it should be : {:?}", - avg_diff, - ppr_ans, - true_ans - ) - } - #[test] - fn ppr_toy_graph_init_b() { - let (graph, true_ans, indexes) = toy_graph_with_init_b(); - let mut source_bias = HashMap::new(); - source_bias.insert(indexes[1], 1.0); - - let ppr_ans = naive_ppr(&graph, 0.15_f64, source_bias, 15); - let ans_sum = ppr_ans.values().copied().sum::(); - assert!(ans_sum - 1.0 < f64::EPSILON); - - let avg_diff = 0.25 - * indexes - .iter() - .map(|idx| { - let actual = ppr_ans[idx]; - let expected = true_ans[idx]; - diff(actual, expected) - }) - .sum::(); - - assert!( - avg_diff < 0.005, - "failed with avg_diff {}, whole ppr_vec is : {:?}, but it should be : {:?}", - avg_diff, - ppr_ans, - true_ans - ) - } - #[test] - fn ppr_toy_graph_init_ab() { - let (graph, true_ans, indexes) = toy_graph_with_init_ab(); - let mut source_bias = HashMap::new(); - source_bias.insert(indexes[0], 1.0); - source_bias.insert(indexes[1], 1.0); - - let ppr_ans = naive_ppr(&graph, 0.15_f64, source_bias, 15); - let ans_sum = ppr_ans.values().copied().sum::(); - assert!(ans_sum - 1.0 < f64::EPSILON); - - let avg_diff = 0.25 - * indexes - .iter() - .map(|idx| { - let actual = ppr_ans[idx]; - let expected = true_ans[idx]; - diff(actual, expected) - 
}) - .sum::(); - - assert!( - avg_diff < 0.005, - "failed with avg_diff {}, whole ppr_vec is : {:?}, but it should be : {:?}", - avg_diff, - ppr_ans, - true_ans - ) - } - #[test] - fn pressure_large_graph_test() { - let (graph, nodes) = pressure_large_graph(); - let mut source_bias = HashMap::new(); - nodes.iter().take(10).for_each(|idx| { - source_bias.insert(*idx, graph.to_index(*idx) as f64); - }); - - let ppr_ans = naive_ppr(&graph, 0.15_f64, source_bias, 15); - let ans_sum = ppr_ans.values().copied().sum::(); - assert!(ans_sum - 1.0 < f64::EPSILON); - } -} +pub mod ord_float; +pub mod ppr; From 6494d7c622a3ea58c4a7743cc834e173cfe26868 Mon Sep 17 00:00:00 2001 From: dynamder Date: Thu, 18 Dec 2025 17:33:25 +0800 Subject: [PATCH 18/22] Update Cargo.lock --- Cargo.lock | 145 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 145 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 97e7ae3..4088e5e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -91,6 +91,15 @@ dependencies = [ "equator", ] +[[package]] +name = "alloca" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a7d05ea6aea7e9e64d25b9156ba2fee3fdd659e34e41063cd2fc7cd020d7f4" +dependencies = [ + "cc", +] + [[package]] name = "allocator-api2" version = "0.2.21" @@ -125,6 +134,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstyle" version = "1.0.11" @@ -731,6 +746,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "castaway" version = "0.2.3" @@ -902,6 +923,31 @@ dependencies = [ "libloading", ] +[[package]] +name = "clap" +version = "4.5.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9e340e012a1bf4935f5282ed1436d1489548e8f72308207ea5df0e23d2d03f8" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.5.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d76b5d13eaa18c901fd2f7fca939fefe3a0727a953561fefdf3b2922b8569d00" +dependencies = [ + "anstyle", + "clap_lex", +] + +[[package]] +name = "clap_lex" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" + [[package]] name = "cmsketch" version = "0.2.4" @@ -1012,6 +1058,41 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d883447757bb0ee46f233e9dc22eb84d93a9508c9b868687b274fc431d886bf" +dependencies = [ + "alloca", + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "itertools 0.13.0", + "num-traits", + "oorandom", + "page_size", + "plotters", + "rayon", + "regex", + "serde", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed943f81ea2faa8dcecbbfa50164acf95d555afec96a27871663b300e387b2e4" +dependencies = [ + "cast", + "itertools 0.13.0", +] + [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -2666,6 +2747,15 @@ dependencies = [ "either", ] 
+[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.14.0" @@ -3559,6 +3649,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + [[package]] name = "openssl" version = "0.10.73" @@ -3643,6 +3739,16 @@ dependencies = [ "ureq 3.1.2", ] +[[package]] +name = "page_size" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "panic-message" version = "0.3.0" @@ -3904,6 +4010,34 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "png" version = "0.17.16" @@ -5274,6 +5408,7 @@ dependencies = [ "approx 0.5.1", "async-trait", "chrono", + "criterion", "dotenvy", "fastembed", "formatx", @@ -5839,6 +5974,16 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.9.0" From b272c1ec208a74d6681a6c11d231907b8b47b73d Mon Sep 17 00:00:00 2001 From: dynamder Date: Thu, 18 Dec 2025 17:33:51 +0800 Subject: [PATCH 19/22] refactor graph_algo, sketching weighted_ppr_fp --- src/utils/graph_algo/ord_float.rs | 86 ++++++ src/utils/graph_algo/ppr.rs | 484 ++++++++++++++++++++++++++++++ 2 files changed, 570 insertions(+) create mode 100644 src/utils/graph_algo/ord_float.rs create mode 100644 src/utils/graph_algo/ppr.rs diff --git a/src/utils/graph_algo/ord_float.rs b/src/utils/graph_algo/ord_float.rs new file mode 100644 index 0000000..759b935 --- /dev/null +++ b/src/utils/graph_algo/ord_float.rs @@ -0,0 +1,86 @@ +use core::iter::Sum; +use ordered_float::{FloatCore, OrderedFloat, PrimitiveFloat}; +use petgraph::algo::UnitMeasure; +use std::fmt::Debug; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct OrdFloat(OrderedFloat); +impl Default for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + Default, +{ + fn default() -> Self { + OrdFloat(OrderedFloat::default()) + } +} +impl Sum for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + Sum, +{ + fn 
sum>(iter: I) -> Self { + OrdFloat(OrderedFloat(iter.map(|f| f.0.0).sum())) + } +} +impl std::ops::Add for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + std::ops::Add, +{ + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + OrdFloat(self.0 + rhs.0) + } +} +impl std::ops::Sub for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + std::ops::Sub, +{ + type Output = Self; + fn sub(self, rhs: Self) -> Self::Output { + OrdFloat(self.0 - rhs.0) + } +} +impl std::ops::Mul for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + std::ops::Mul, +{ + type Output = Self; + fn mul(self, rhs: Self) -> Self::Output { + OrdFloat(self.0 * rhs.0) + } +} +impl std::ops::Div for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + std::ops::Div, +{ + type Output = Self; + fn div(self, rhs: Self) -> Self::Output { + OrdFloat(self.0 / rhs.0) + } +} + +impl UnitMeasure for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + Debug + Sum + Default, +{ + fn default_tol() -> Self { + OrdFloat(OrderedFloat((F::epsilon()))) + } + ///如果F不能从f32转化,则变为NaN而非panic + fn from_f32(val: f32) -> Self { + OrdFloat(OrderedFloat(F::from(val).unwrap_or(F::nan()))) + } + ///如果F不能从f64转化,则变为NaN而非panic + fn from_f64(val: f64) -> Self { + OrdFloat(OrderedFloat(F::from(val).unwrap_or(F::nan()))) + } + ///如果F不能从usize转化,则变为NaN而非panic + fn from_usize(nb: usize) -> Self { + OrdFloat(OrderedFloat(F::from(nb).unwrap_or(F::nan()))) + } + fn one() -> Self { + OrdFloat(OrderedFloat(F::one())) + } + fn zero() -> Self { + OrdFloat(OrderedFloat(F::zero())) + } +} diff --git a/src/utils/graph_algo/ppr.rs b/src/utils/graph_algo/ppr.rs new file mode 100644 index 0000000..5715caf --- /dev/null +++ b/src/utils/graph_algo/ppr.rs @@ -0,0 +1,484 @@ +use std::{ + cmp::Ordering, + collections::{BinaryHeap, HashMap}, + fmt::Debug, + hash::Hash, +}; + +use petgraph::{ + Direction::Incoming, + algo::UnitMeasure, + visit::{EdgeRef, IntoEdges, IntoNodeIdentifiers, NodeCount, NodeIndexable}, +}; + +///PPR: ppr_s = dampling_factor * P * ppr_s + (1-damping_factor) * personalized_vec, P为转移矩阵 +/// 对无出度的节点,采取与source_bias中的节点建立连接 +/// 必须保证source_bias的key是有效的NodeId, 否则会得到不正确的结果 +// 由于NodeId会由MemoryCluster提供,这不会造成额外的检查负担 +#[track_caller] +pub fn naive_ppr( + graph: G, + damping_factor: D, + personalized_vec: HashMap, + nb_iter: usize, +) -> HashMap +where + G: NodeCount + IntoEdges + NodeIndexable + IntoNodeIdentifiers, + D: UnitMeasure + Copy, + G::NodeId: Hash + Eq, +{ + let node_count = graph.node_count(); + if node_count == 0 { + return HashMap::new(); + } + + //检查阻尼系数 + assert!( + D::zero() <= damping_factor && damping_factor <= D::one(), + "Damping factor should be between 0 et 1." 
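+        // Editorial note: in this formulation the damping factor is the probability of
+        // continuing the walk, and (1 - damping_factor) is the restart probability, so the
+        // tests below pass 0.15 (i.e. an 85% chance of jumping back to the personalized
+        // sources at every step). A minimal usage sketch, assuming `graph` is a petgraph
+        // StableDiGraph and `node_a` is one of its NodeIndex handles (hypothetical names):
+        //
+        //     let mut bias = HashMap::new();
+        //     bias.insert(node_a, 1.0_f64);
+        //     let ranks = naive_ppr(&graph, 0.15_f64, bias, 15);
+        //     assert!((ranks.values().sum::<f64>() - 1.0).abs() < 1e-9);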
+ ); + + //检查个性化分布是不是一个概率分布 + let personalized_sum: D = personalized_vec.values().copied().sum(); + assert!( + personalized_sum > D::zero(), + "Personalized Source bias sum must be positive" + ); + + //归一化个性化向量(初始向量) + let normalized_personalized_vec: HashMap = if personalized_sum != D::one() { + personalized_vec + .into_iter() + .map(|(node_id, bias)| (node_id, bias / personalized_sum)) + .collect() + } else { + personalized_vec + }; + + //图中有效的索引值,适配StableGraph(索引可能不连续) + let valid_index = graph + .node_identifiers() + .map(|node_id| graph.to_index(node_id)) + .collect::>(); + + //ppr值的存储 + //此处可能有大量内存浪费(无效的索引值占位),考虑到工作记忆子图不会过于频繁释放和加载,这个内存开销应该是可以接受的 + let mut ppr_ranks = vec![D::zero(); graph.node_bound()]; + let mut out_degrees = vec![D::zero(); graph.node_bound()]; + + //使用个性化向量,初始化PPR值向量,由于源节点有向量相似性取top-k提供(k通常不大),这样初始化通常可以加快收敛速度 + normalized_personalized_vec + .iter() + .for_each(|(&node_id, &bias)| { + ppr_ranks[graph.to_index(node_id)] = bias; //SAFEUNWRAP: 已经预先分配了索引上限大小的内存,不会越界访问。 + }); + let normalized_bias_len = normalized_personalized_vec.len(); + //println!("normalized_bias: {:?}", normalized_bias); + + //预计算每个节点的出度 + graph.node_identifiers().for_each(|node_id| { + out_degrees[graph.to_index(node_id)] = graph.edges(node_id).map(|_| D::one()).sum(); + }); + //println!("out_degrees: {:?}", out_degrees); + + for i in 0..nb_iter { + let ppr_vec_i = valid_index + .iter() + .map(|&computing_idx| { + let iter_ppr = valid_index + .iter() + .map(|&idx| { + //找到每个节点的出边 + let mut out_edges = graph.edges(graph.from_index(idx)); + + //游走部分的计算,对于无出度节点,默认其连接至所有个性化向量中不为0的节点 + if out_edges.any(|e| e.target() == graph.from_index(computing_idx)) { + damping_factor * ppr_ranks[idx] / out_degrees[idx] + } else if out_degrees[idx] == D::zero() { + normalized_personalized_vec + .get(&graph.from_index(computing_idx)) + .map(|_| { + damping_factor * ppr_ranks[idx] + / D::from_usize(normalized_bias_len) + }) + .unwrap_or(D::zero()) + } else { + D::zero() + } + }) + .sum::(); + + //随机重启部分计算 + let random_back_part = if let Some(per_i) = + normalized_personalized_vec.get(&graph.from_index(computing_idx)) + { + (D::one() - damping_factor) * *per_i + } else { + D::zero() + }; + + (computing_idx, iter_ppr + random_back_part) + }) + .collect::>(); + + // 归一化PPR值,确保数值稳定,总和为1 + let sum = ppr_vec_i.iter().map(|(_, ppr)| *ppr).sum::(); + + ppr_vec_i.iter().for_each(|&(idx, ppr)| { + ppr_ranks[idx] = ppr / sum; + }); + //println!("iteration {i}: PPR values: {:?}", ppr_ranks); + } + + //最终归一化 + let sum = ppr_ranks.iter().map(|ppr| *ppr).sum::(); + + //返回PPR向量,HashMap形式 + graph + .node_identifiers() + .map(|node_id| (node_id, ppr_ranks[graph.to_index(node_id)] / sum)) + .collect() +} + +//用于BinaryHeap的残差单元表示 +struct ResidueUnit { + pub idx: Index, + pub value: D_R, +} +impl PartialOrd for ResidueUnit +where + D_R: UnitMeasure + Copy + PartialOrd, +{ + fn ge(&self, other: &Self) -> bool { + self.value.ge(&other.value) + } + fn gt(&self, other: &Self) -> bool { + self.value.gt(&other.value) + } + fn le(&self, other: &Self) -> bool { + self.value.le(&other.value) + } + fn lt(&self, other: &Self) -> bool { + self.value.lt(&other.value) + } + fn partial_cmp(&self, other: &Self) -> Option { + self.value.partial_cmp(&other.value) + } +} +impl PartialEq for ResidueUnit +where + D_R: UnitMeasure + Copy + PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.value.eq(&other.value) + } + fn ne(&self, other: &Self) -> bool { + self.value.ne(&other.value) + } +} +impl Eq for ResidueUnit where D_R: UnitMeasure + Copy + Eq 
{} +impl Ord for ResidueUnit +where + D_R: UnitMeasure + Copy + Ord + PartialOrd + Eq, +{ + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.value.cmp(&other.value) + } + fn clamp(self, min: Self, max: Self) -> Self + where + Self: Sized, + { + Self { + idx: self.idx, + value: self.value.clamp(min.value, max.value), + } + } + fn max(self, other: Self) -> Self + where + Self: Sized, + { + match self.cmp(&other) { + Ordering::Less => other, + Ordering::Equal => other, + Ordering::Greater => self, + } + } + fn min(self, other: Self) -> Self + where + Self: Sized, + { + match self.cmp(&other) { + Ordering::Less => self, + Ordering::Equal => self, + Ordering::Greater => other, + } + } +} +type EdgeWeightUnit = ResidueUnit; + +#[track_caller] +pub fn weighted_ppr_fp( + graph: G, + damping_factor: D, + personalized_vec: HashMap, + residue_threshold: D, + weight_calc: impl Fn(&G::EdgeRef, &Q) -> D, + dynamic_query: &Q, +) -> HashMap +where + G: NodeCount + IntoEdges + NodeIndexable + IntoNodeIdentifiers, + D: UnitMeasure + Copy + Ord, + G::NodeId: Hash + Eq, + G::EdgeId: Hash + Eq, +{ + let personalized_sum = personalized_vec.values().copied().sum::(); + assert!( + personalized_sum > D::zero(), + "Personalized Source bias sum must be positive" + ); + + let normalized_personalized_vec: HashMap = if personalized_sum != D::one() { + personalized_vec + .into_iter() + .map(|(node_id, bias)| (node_id, bias / personalized_sum)) + .collect() + } else { + personalized_vec + }; + let source_node_count = D::from_usize(normalized_personalized_vec.len()); + + let mut reserve_vec = vec![D::zero(); graph.node_bound()]; + let mut residue_vec = (0..graph.node_bound()) + .map(|i| { + let residue_i = normalized_personalized_vec + .get(&graph.from_index(i)) + .copied() + .unwrap_or(D::zero()); + ResidueUnit { + idx: i, + value: residue_i, + } + }) + .collect::>(); + + let mut ppr_edge_weight_cache: HashMap>> = + HashMap::with_capacity(graph.node_count()); + + while let Some(residue_i) = residue_vec.pop() { + let out_edges = graph.edges(graph.from_index(residue_i.idx)); + //动态归一化的边权计算 + if !ppr_edge_weight_cache.contains_key(&graph.from_index(residue_i.idx)) { + let weights = out_edges + .map(|edge| { + let weight = weight_calc(&edge, dynamic_query); + EdgeWeightUnit { + idx: edge.id(), + value: weight, + } + }) + .collect::>(); + let sum = weights.iter().map(|v| v.value).sum::(); + let weights = weights + .into_iter() + .map(|w| EdgeWeightUnit { + idx: w.id, + value: w.value / sum, + }) + .collect::>(); + ppr_edge_weight_cache.insert(graph.from_index(residue_i.idx), weights); + } + + let edge_weights = &ppr_edge_weight_cache[&graph.from_index(residue_i.idx)]; + + //残差push + if let Some(edge_weight_max) = edge_weights.iter().max() { + if residue_i.value * edge_weight_max.value > residue_threshold { + todo!("push residue by weights") + } + } else { + if residue_i.value / source_node_count > residue_threshold { + todo!("push residue by weights") + } + } + } + todo!() +} + +#[cfg(test)] +mod test { + + use mockall::predicate::float; + use petgraph::{matrix_graph::NodeIndex, prelude::StableDiGraph}; + + use super::*; + fn diff(actual: f64, expected: f64) -> f64 { + if expected.abs() < f64::EPSILON && actual.abs() < f64::EPSILON { + 0.0 + } else { + let diff = (actual - expected).abs(); + diff + } + } + fn pressure_large_graph() -> (StableDiGraph, Vec>) { + let mut graph = StableDiGraph::new(); + let mut nodes = Vec::new(); + for i in 0..500 { + let mut node = graph.add_node("".to_string()); + if i % 2 == 0 || i 
% 7 == 0 { + graph.remove_node(node); + node = graph.add_node("".to_string()); + } + nodes.push(node); + graph.add_edge(node, node, 1.0); + nodes.iter().for_each(|idx| { + graph.add_edge(node, *idx, 1.0); + }); + } + (graph, nodes) + } + + fn test_toy_graph() -> (StableDiGraph, Vec>) { + let mut graph = StableDiGraph::new(); + let a = graph.add_node("A".to_string()); + let b = graph.add_node("B".to_string()); + //制造索引空洞 + graph.remove_node(b); + let b = graph.add_node("B".to_string()); + let c = graph.add_node("C".to_string()); + let d = graph.add_node("D".to_string()); + + graph.add_edge(a, b, 1.0); + graph.add_edge(a, c, 1.0); + graph.add_edge(b, c, 1.0); + graph.add_edge(c, d, 1.0); + + (graph, vec![a, b, c, d]) + } + fn toy_graph_with_init_a() -> ( + StableDiGraph, + HashMap, f64>, + Vec>, + ) { + let (graph, indexes) = test_toy_graph(); + let ans_vec: Vec = vec![0.851652742, 0.06387396045, 0.07345504972, 0.01101824785]; + let ans = indexes.iter().copied().zip(ans_vec).collect(); + (graph, ans, indexes) + } + fn toy_graph_with_init_b() -> ( + StableDiGraph, + HashMap, f64>, + Vec>, + ) { + let (graph, indexes) = test_toy_graph(); + let ans_vec: Vec = vec![0.0, 0.852878432, 0.1279320211, 0.01918954688]; + let ans = indexes.iter().copied().zip(ans_vec).collect(); + (graph, ans, indexes) + } + fn toy_graph_with_init_ab() -> ( + StableDiGraph, + HashMap, f64>, + Vec>, + ) { + let (graph, indexes) = test_toy_graph(); + let ans_vec: Vec = vec![0.4261326137, 0.4580925718, 0.1006738318, 0.00510098267]; + let ans = indexes.iter().copied().zip(ans_vec).collect(); + (graph, ans, indexes) + } + #[test] + fn ppr_toy_graph_init_a() { + let (graph, true_ans, indexes) = toy_graph_with_init_a(); + let mut source_bias = HashMap::new(); + source_bias.insert(indexes[0], 1.0); + + let ppr_ans = naive_ppr(&graph, 0.15_f64, source_bias, 15); + let ans_sum = ppr_ans.values().copied().sum::(); + assert!(ans_sum - 1.0 < f64::EPSILON); + + let avg_diff = 0.25 + * indexes + .iter() + .map(|idx| { + let actual = ppr_ans[idx]; + let expected = true_ans[idx]; + diff(actual, expected) + }) + .sum::(); + + assert!( + avg_diff < 0.005, + "failed with avg_diff {}, whole ppr_vec is : {:?}, but it should be : {:?}", + avg_diff, + ppr_ans, + true_ans + ) + } + #[test] + fn ppr_toy_graph_init_b() { + let (graph, true_ans, indexes) = toy_graph_with_init_b(); + let mut source_bias = HashMap::new(); + source_bias.insert(indexes[1], 1.0); + + let ppr_ans = naive_ppr(&graph, 0.15_f64, source_bias, 15); + let ans_sum = ppr_ans.values().copied().sum::(); + assert!(ans_sum - 1.0 < f64::EPSILON); + + let avg_diff = 0.25 + * indexes + .iter() + .map(|idx| { + let actual = ppr_ans[idx]; + let expected = true_ans[idx]; + diff(actual, expected) + }) + .sum::(); + + assert!( + avg_diff < 0.005, + "failed with avg_diff {}, whole ppr_vec is : {:?}, but it should be : {:?}", + avg_diff, + ppr_ans, + true_ans + ) + } + #[test] + fn ppr_toy_graph_init_ab() { + let (graph, true_ans, indexes) = toy_graph_with_init_ab(); + let mut source_bias = HashMap::new(); + source_bias.insert(indexes[0], 1.0); + source_bias.insert(indexes[1], 1.0); + + let ppr_ans = naive_ppr(&graph, 0.15_f64, source_bias, 15); + let ans_sum = ppr_ans.values().copied().sum::(); + assert!(ans_sum - 1.0 < f64::EPSILON); + + let avg_diff = 0.25 + * indexes + .iter() + .map(|idx| { + let actual = ppr_ans[idx]; + let expected = true_ans[idx]; + diff(actual, expected) + }) + .sum::(); + + assert!( + avg_diff < 0.005, + "failed with avg_diff {}, whole ppr_vec is : {:?}, but it 
should be : {:?}", + avg_diff, + ppr_ans, + true_ans + ) + } + #[test] + fn pressure_large_graph_test() { + let (graph, nodes) = pressure_large_graph(); + let mut source_bias = HashMap::new(); + nodes.iter().take(10).for_each(|idx| { + source_bias.insert(*idx, graph.to_index(*idx) as f64); + }); + + let ppr_ans = naive_ppr(&graph, 0.15_f64, source_bias, 15); + let ans_sum = ppr_ans.values().copied().sum::(); + assert!(ans_sum - 1.0 < f64::EPSILON); + } +} From 0371a99d8a46454305309e338126fa5ab478ff76 Mon Sep 17 00:00:00 2001 From: dynamder Date: Thu, 18 Dec 2025 17:34:33 +0800 Subject: [PATCH 20/22] fix typo --- src/utils/graph_algo/ppr.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/graph_algo/ppr.rs b/src/utils/graph_algo/ppr.rs index 5715caf..6cc129f 100644 --- a/src/utils/graph_algo/ppr.rs +++ b/src/utils/graph_algo/ppr.rs @@ -281,7 +281,7 @@ where let weights = weights .into_iter() .map(|w| EdgeWeightUnit { - idx: w.id, + idx: w.idx, value: w.value / sum, }) .collect::>(); From df9fbb5aa1ca4e2bb65cb92247bffb5c2f6d6677 Mon Sep 17 00:00:00 2001 From: dynamder Date: Wed, 31 Dec 2025 18:22:02 +0800 Subject: [PATCH 21/22] implemented weighted_ppr_fp, some test failed. --- src/utils/graph_algo/ord_float.rs | 90 ++++++++++- src/utils/graph_algo/ppr.rs | 238 ++++++++++++++++++++++++++++-- 2 files changed, 312 insertions(+), 16 deletions(-) diff --git a/src/utils/graph_algo/ord_float.rs b/src/utils/graph_algo/ord_float.rs index 759b935..4807a93 100644 --- a/src/utils/graph_algo/ord_float.rs +++ b/src/utils/graph_algo/ord_float.rs @@ -3,8 +3,16 @@ use ordered_float::{FloatCore, OrderedFloat, PrimitiveFloat}; use petgraph::algo::UnitMeasure; use std::fmt::Debug; -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Hash)] pub struct OrdFloat(OrderedFloat); +impl OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat, +{ + pub fn into_inner(self) -> F { + self.0.into_inner() + } +} impl Default for OrdFloat where F: ordered_float::FloatCore + PrimitiveFloat + Default, @@ -30,6 +38,14 @@ where OrdFloat(self.0 + rhs.0) } } +impl std::ops::AddAssign for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + std::ops::AddAssign, +{ + fn add_assign(&mut self, rhs: Self) { + self.0 += rhs.0; + } +} impl std::ops::Sub for OrdFloat where F: ordered_float::FloatCore + PrimitiveFloat + std::ops::Sub, @@ -39,6 +55,14 @@ where OrdFloat(self.0 - rhs.0) } } +impl std::ops::SubAssign for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + std::ops::SubAssign, +{ + fn sub_assign(&mut self, rhs: Self) { + self.0 -= rhs.0; + } +} impl std::ops::Mul for OrdFloat where F: ordered_float::FloatCore + PrimitiveFloat + std::ops::Mul, @@ -48,6 +72,14 @@ where OrdFloat(self.0 * rhs.0) } } +impl std::ops::MulAssign for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + std::ops::MulAssign, +{ + fn mul_assign(&mut self, rhs: Self) { + self.0 *= rhs.0; + } +} impl std::ops::Div for OrdFloat where F: ordered_float::FloatCore + PrimitiveFloat + std::ops::Div, @@ -57,7 +89,23 @@ where OrdFloat(self.0 / rhs.0) } } - +impl std::ops::DivAssign for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + std::ops::DivAssign, +{ + fn div_assign(&mut self, rhs: Self) { + self.0 /= rhs.0; + } +} +impl std::ops::Neg for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + std::ops::Neg, +{ + type Output = Self; + fn neg(self) -> Self::Output { + OrdFloat(-self.0) + } +} impl UnitMeasure 
for OrdFloat where F: ordered_float::FloatCore + PrimitiveFloat + Debug + Sum + Default, @@ -84,3 +132,41 @@ where OrdFloat(OrderedFloat(F::zero())) } } +impl From for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat, +{ + fn from(val: F) -> Self { + OrdFloat(OrderedFloat(val)) + } +} +impl Eq for OrdFloat where F: ordered_float::FloatCore + PrimitiveFloat {} + +impl PartialOrd for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat, +{ + fn ge(&self, other: &Self) -> bool { + self.0.ge(&other.0) + } + fn gt(&self, other: &Self) -> bool { + self.0.gt(&other.0) + } + fn le(&self, other: &Self) -> bool { + self.0.le(&other.0) + } + fn lt(&self, other: &Self) -> bool { + self.0.lt(&other.0) + } + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} +impl Ord for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat, +{ + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0.cmp(&other.0) + } +} diff --git a/src/utils/graph_algo/ppr.rs b/src/utils/graph_algo/ppr.rs index 6cc129f..f793be2 100644 --- a/src/utils/graph_algo/ppr.rs +++ b/src/utils/graph_algo/ppr.rs @@ -3,8 +3,10 @@ use std::{ collections::{BinaryHeap, HashMap}, fmt::Debug, hash::Hash, + ops::AddAssign, }; +use super::ord_float::OrdFloat; use petgraph::{ Direction::Incoming, algo::UnitMeasure, @@ -140,13 +142,15 @@ where .collect() } -//用于BinaryHeap的残差单元表示 -struct ResidueUnit { +//残差单元表示 +#[derive(Debug, Clone, Copy)] +struct ResidueUnit { pub idx: Index, pub value: D_R, } impl PartialOrd for ResidueUnit where + Index: Copy, D_R: UnitMeasure + Copy + PartialOrd, { fn ge(&self, other: &Self) -> bool { @@ -167,6 +171,7 @@ where } impl PartialEq for ResidueUnit where + Index: Copy, D_R: UnitMeasure + Copy + PartialEq, { fn eq(&self, other: &Self) -> bool { @@ -176,9 +181,15 @@ where self.value.ne(&other.value) } } -impl Eq for ResidueUnit where D_R: UnitMeasure + Copy + Eq {} +impl Eq for ResidueUnit +where + Index: Copy, + D_R: UnitMeasure + Copy + Eq, +{ +} impl Ord for ResidueUnit where + Index: Copy, D_R: UnitMeasure + Copy + Ord + PartialOrd + Eq, { fn cmp(&self, other: &Self) -> std::cmp::Ordering { @@ -214,9 +225,92 @@ where } } } -type EdgeWeightUnit = ResidueUnit; + +//边权的单元表示 +#[derive(Debug)] +pub struct EdgeWeightUnit +where + D: UnitMeasure + Copy, +{ + pub target_node: NodeIdx, + pub idx: EdgeIdx, + pub value: D, +} +impl PartialOrd for EdgeWeightUnit +where + D: UnitMeasure + Copy + PartialOrd, +{ + fn ge(&self, other: &Self) -> bool { + self.value.ge(&other.value) + } + fn gt(&self, other: &Self) -> bool { + self.value.gt(&other.value) + } + fn le(&self, other: &Self) -> bool { + self.value.le(&other.value) + } + fn lt(&self, other: &Self) -> bool { + self.value.lt(&other.value) + } + fn partial_cmp(&self, other: &Self) -> Option { + self.value.partial_cmp(&other.value) + } +} +impl PartialEq for EdgeWeightUnit +where + D: UnitMeasure + Copy + PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.value.eq(&other.value) + } + fn ne(&self, other: &Self) -> bool { + self.value.ne(&other.value) + } +} +impl Eq for EdgeWeightUnit where D: UnitMeasure + Copy + Eq +{} + +impl Ord for EdgeWeightUnit +where + D: UnitMeasure + Copy + Ord + PartialOrd + Eq, +{ + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.value.cmp(&other.value) + } + fn clamp(self, min: Self, max: Self) -> Self + where + Self: Sized, + { + Self { + target_node: self.target_node, + idx: self.idx, + value: self.value.clamp(min.value, max.value), + } + } + fn max(self, 
other: Self) -> Self + where + Self: Sized, + { + match self.cmp(&other) { + Ordering::Less => other, + Ordering::Equal => other, + Ordering::Greater => self, + } + } + fn min(self, other: Self) -> Self + where + Self: Sized, + { + match self.cmp(&other) { + Ordering::Less => self, + Ordering::Equal => self, + Ordering::Greater => other, + } + } +} #[track_caller] +//TODO: make damping factor specific to each node pub fn weighted_ppr_fp( graph: G, damping_factor: D, @@ -227,10 +321,11 @@ pub fn weighted_ppr_fp( ) -> HashMap where G: NodeCount + IntoEdges + NodeIndexable + IntoNodeIdentifiers, - D: UnitMeasure + Copy + Ord, - G::NodeId: Hash + Eq, - G::EdgeId: Hash + Eq, + D: UnitMeasure + Copy + AddAssign + Ord, + G::NodeId: Hash + Eq + Debug, //TODO: delete the Debug Trait bound + G::EdgeId: Hash + Eq + Debug, { + //归一化个性化向量 let personalized_sum = personalized_vec.values().copied().sum::(); assert!( personalized_sum > D::zero(), @@ -245,8 +340,11 @@ where } else { personalized_vec }; + let source_node_count = D::from_usize(normalized_personalized_vec.len()); + println!("source_node_count: {:?}", source_node_count); + //初始化残差和保留 let mut reserve_vec = vec![D::zero(); graph.node_bound()]; let mut residue_vec = (0..graph.node_bound()) .map(|i| { @@ -259,19 +357,25 @@ where value: residue_i, } }) - .collect::>(); + .collect::>(); - let mut ppr_edge_weight_cache: HashMap>> = - HashMap::with_capacity(graph.node_count()); + let mut ppr_edge_weight_cache: HashMap< + G::NodeId, + Vec>, + > = HashMap::with_capacity(graph.node_count()); - while let Some(residue_i) = residue_vec.pop() { + //每次取残差最大的节点进行push,加速收敛 + while let Some(residue_i) = residue_vec.iter().copied().max() { + println!("Processing node {}", residue_i.idx); let out_edges = graph.edges(graph.from_index(residue_i.idx)); //动态归一化的边权计算 if !ppr_edge_weight_cache.contains_key(&graph.from_index(residue_i.idx)) { + println!("Calculating edge weights for node {}", residue_i.idx); let weights = out_edges .map(|edge| { let weight = weight_calc(&edge, dynamic_query); EdgeWeightUnit { + target_node: edge.target(), idx: edge.id(), value: weight, } @@ -281,6 +385,7 @@ where let weights = weights .into_iter() .map(|w| EdgeWeightUnit { + target_node: w.target_node, idx: w.idx, value: w.value / sum, }) @@ -289,19 +394,45 @@ where } let edge_weights = &ppr_edge_weight_cache[&graph.from_index(residue_i.idx)]; + println!("edge_weights: {:?}", edge_weights); + //清空当前节点残差 + residue_vec[residue_i.idx].value = D::zero(); + + //将部分残差转为保留 + reserve_vec[residue_i.idx] += (D::one() - damping_factor) * residue_i.value; //残差push if let Some(edge_weight_max) = edge_weights.iter().max() { + //节点出度不为0的情况 if residue_i.value * edge_weight_max.value > residue_threshold { - todo!("push residue by weights") + edge_weights.iter().for_each(|edge_w| { + residue_vec[graph.to_index(edge_w.target_node)].value += + damping_factor * edge_w.value * residue_i.value; + }); + } else { + break; } } else { + //节点出度为0的情况 if residue_i.value / source_node_count > residue_threshold { - todo!("push residue by weights") + normalized_personalized_vec.keys().for_each(|node| { + residue_vec[graph.to_index(*node)].value += + damping_factor * residue_i.value / source_node_count; + }); + } else { + break; } } } - todo!() + let sum = reserve_vec.iter().copied().sum::(); + + graph + .node_identifiers() + .map(|node| { + let ppr_value = reserve_vec[graph.to_index(node)] / sum; + (node, ppr_value) + }) + .collect() } #[cfg(test)] @@ -470,6 +601,85 @@ mod test { ) } #[test] + fn 
ppr_forward_push_toy_graph_init_a() { + let (graph, true_ans, indexes) = toy_graph_with_init_a(); + let mut source_bias = HashMap::new(); + source_bias.insert(indexes[0], OrdFloat::from_f64(1.0)); + + let ppr_ans = weighted_ppr_fp( + &graph, + OrdFloat::from_f64(0.15), + source_bias, + OrdFloat::from_f64(0.002), + |_, _| OrdFloat::from_f64(1.0), + &"1", + ); + let ans_sum: f64 = ppr_ans + .values() + .copied() + .sum::>() + .into_inner(); + assert!(ans_sum - 1.0 < f64::EPSILON); + + let avg_diff = 0.25 + * indexes + .iter() + .map(|idx| { + let actual: f64 = ppr_ans[idx].into_inner(); + let expected = true_ans[idx]; + diff(actual, expected) + }) + .sum::(); + + assert!( + avg_diff < 0.005, + "failed with avg_diff {}, whole ppr_vec is : {:?}, but it should be : {:?}", + avg_diff, + ppr_ans, + true_ans + ) + } + #[test] + //TODO: pass this test + fn ppr_forward_push_toy_graph_init_b() { + let (graph, true_ans, indexes) = toy_graph_with_init_b(); + let mut source_bias = HashMap::new(); + source_bias.insert(indexes[0], OrdFloat::from_f64(1.0)); + + let ppr_ans = weighted_ppr_fp( + &graph, + OrdFloat::from_f64(0.15), + source_bias, + OrdFloat::from_f64(0.002), + |_, _| OrdFloat::from_f64(1.0), + &"1", + ); + let ans_sum: f64 = ppr_ans + .values() + .copied() + .sum::>() + .into_inner(); + assert!(ans_sum - 1.0 < f64::EPSILON); + + let avg_diff = 0.25 + * indexes + .iter() + .map(|idx| { + let actual: f64 = ppr_ans[idx].into_inner(); + let expected = true_ans[idx]; + diff(actual, expected) + }) + .sum::(); + + assert!( + avg_diff < 0.005, + "failed with avg_diff {}, whole ppr_vec is : {:?}, but it should be : {:?}", + avg_diff, + ppr_ans, + true_ans + ) + } + #[test] fn pressure_large_graph_test() { let (graph, nodes) = pressure_large_graph(); let mut source_bias = HashMap::new(); From 42b7216a9a0e4ba431759b8e5959e3e391a690d4 Mon Sep 17 00:00:00 2001 From: dynamder Date: Sat, 17 Jan 2026 01:16:57 +0800 Subject: [PATCH 22/22] implemented forward push algorithm with dynamic edge weight, pass simple tests. --- benches/ppr.rs | 348 +++++++++++++++++++++++++++--------- src/utils/graph_algo/ppr.rs | 52 +++++- 2 files changed, 308 insertions(+), 92 deletions(-) diff --git a/benches/ppr.rs b/benches/ppr.rs index 7ed243e..ef8af1a 100644 --- a/benches/ppr.rs +++ b/benches/ppr.rs @@ -1,55 +1,132 @@ -use criterion::{ - BenchmarkGroup, Criterion, SamplingMode, black_box, criterion_group, criterion_main, -}; -use petgraph::{matrix_graph::NodeIndex, prelude::StableDiGraph, visit::NodeIndexable}; -use soul_mem::utils::graph_algo::ppr::naive_ppr; +//! PPR算法性能基准测试 +//! +//! 本文件包含两种PPR(Personalized PageRank)算法的性能对比测试: +//! 1. Power Iteration算法:传统迭代方法,收敛性好但较慢 +//! 2. Forward Push算法:近似算法,速度极快但有一定精度损失 +//! +//! 使用方法: +//! 1. 运行所有测试: cargo bench +//! 2. 运行特定测试: cargo bench <测试组名> +//! 3. 
测试结果会生成在target/criterion目录下 + +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use petgraph::stable_graph::NodeIndex; +use petgraph::{prelude::StableDiGraph, visit::NodeIndexable}; +use soul_mem::utils::graph_algo::ord_float::OrdFloat; +use soul_mem::utils::graph_algo::ppr::{naive_ppr, weighted_ppr_fp}; use std::collections::HashMap; +use std::hint::black_box; -/// 创建压力测试图的辅助函数(在benchmark计时范围外使用) -fn pressure_large_graph() -> (StableDiGraph, Vec>) { +/// 创建小型测试图的辅助函数(在benchmark计时范围外使用) +/// 使用稀疏图结构以提高测试效率 +fn create_test_graph(size: usize) -> (StableDiGraph>, Vec>) { let mut graph = StableDiGraph::new(); let mut nodes = Vec::new(); - for i in 0..500 { - let mut node = graph.add_node("".to_string()); - if i % 2 == 0 || i % 7 == 0 { - graph.remove_node(node); - node = graph.add_node("".to_string()); - } + + // 创建指定大小的图 + for i in 0..size { + let node = graph.add_node(format!("node_{}", i)); nodes.push(node); - graph.add_edge(node, node, 1.0); - nodes.iter().for_each(|idx| { - graph.add_edge(node, *idx, 1.0); + + // 每个节点连接到前min(10, 节点数)个节点(如果存在) + nodes.iter().take(10.min(nodes.len())).for_each(|idx| { + graph.add_edge(node, *idx, OrdFloat::from(1.0)); }); + + // 添加自环 + graph.add_edge(node, node, OrdFloat::from(1.0)); + } + (graph, nodes) +} + +/// 创建大规模测试图的辅助函数(在benchmark计时范围外使用) +fn create_large_test_graph( + size: usize, +) -> (StableDiGraph>, Vec>) { + let mut graph = StableDiGraph::new(); + let mut nodes = Vec::new(); + + // 创建指定大小的图(优化版本,减少边数以保持可管理性) + for i in 0..size { + let node = graph.add_node(format!("node_{}", i)); + nodes.push(node); + + // 对于大规模图,限制边的数量以提高性能 + let max_connections = if size > 1000 { 5 } else { 10 }; + nodes + .iter() + .take(max_connections.min(nodes.len())) + .for_each(|idx| { + graph.add_edge(node, *idx, OrdFloat::from(1.0)); + }); + + // 添加自环 + graph.add_edge(node, node, OrdFloat::from(1.0)); } (graph, nodes) } -/// 准备测试数据(在benchmark外执行) -fn prepare_test_data() -> (StableDiGraph, HashMap, f64>) { - let (graph, nodes) = pressure_large_graph(); +/// 简单的边权重计算函数(用于forward push算法) +fn simple_weight_calc( + _edge: &petgraph::stable_graph::EdgeReference>, + _query: &(), +) -> OrdFloat { + OrdFloat::from(1.0) +} + +/// 准备小型测试数据(在benchmark外执行) +/// 用于基础性能对比测试 +fn prepare_test_data() -> ( + StableDiGraph>, + HashMap, OrdFloat>, +) { + let (graph, nodes) = create_test_graph(20); let mut source_bias = HashMap::new(); - nodes.iter().take(10).for_each(|idx| { - source_bias.insert(*idx, graph.to_index(*idx) as f64); + + // 前3个节点作为源节点 + nodes.iter().take(3).for_each(|idx| { + source_bias.insert(*idx, OrdFloat::from(graph.to_index(*idx) as f64)); }); + (graph, source_bias) } -/// 专门测试naive_ppr函数性能的benchmark -fn bench_naive_ppr_function(c: &mut Criterion) { - // 准备测试数据(不计时) +/// 准备大规模测试数据(在benchmark外执行) +/// 专门用于中等和大型规模图的性能测试 +fn prepare_large_test_data( + size: usize, +) -> ( + StableDiGraph>, + HashMap, OrdFloat>, +) { + let (graph, nodes) = create_large_test_graph(size); + let mut source_bias = HashMap::new(); + + // 前min(5, size/100)个节点作为源节点 + let source_count = (5.max(size / 100)).min(nodes.len()); + nodes.iter().take(source_count).for_each(|idx| { + source_bias.insert(*idx, OrdFloat::from(graph.to_index(*idx) as f64)); + }); + + (graph, source_bias) +} + +/// 基础性能对比:Power Iteration vs Forward Push +/// 测试20个节点的小型图,展示两种算法的基本性能差异 +/// 结果显示Forward Push比Power Iteration快约20-30倍 +fn bench_basic_comparison(c: &mut Criterion) { let (graph, source_bias) = prepare_test_data(); - // 设置采样次数为10 - let mut group = c.benchmark_group("naive_ppr_basic"); - 
group.sample_size(10); // 设置采样次数为10 + let mut group = c.benchmark_group("basic_comparison"); + group.sample_size(10); group.sampling_mode(SamplingMode::Flat); - // Benchmark 1: 基础性能测试 - group.bench_function("basic", |b| { + // Power Iteration算法 + group.bench_function("power_iteration_15_iters", |b| { b.iter(|| { let result = naive_ppr( black_box(&graph), - black_box(0.15_f64), + black_box(OrdFloat::from(0.15)), black_box(source_bias.clone()), black_box(15), ); @@ -57,94 +134,189 @@ fn bench_naive_ppr_function(c: &mut Criterion) { }); }); - // Benchmark 2: 测试不同阻尼因子 - group.bench_function("damping_high", |b| { + // Forward Push算法 + group.bench_function("forward_push_threshold_1e-4", |b| { b.iter(|| { - let result = naive_ppr( + let result = weighted_ppr_fp( black_box(&graph), - black_box(0.85_f64), // 高阻尼因子 + black_box(OrdFloat::from(0.15)), black_box(source_bias.clone()), - black_box(15), + black_box(OrdFloat::from(0.0001)), + black_box(simple_weight_calc), + black_box(&()), ); black_box(result); }); }); - // Benchmark 3: 测试不同迭代次数 - group.bench_function("iterations_50", |b| { - b.iter(|| { - let result = naive_ppr( - black_box(&graph), - black_box(0.15_f64), - black_box(source_bias.clone()), - black_box(50), // 更多迭代次数 - ); - black_box(result); - }); - }); + group.finish(); +} - // Benchmark 4: 测试更少的迭代次数 - group.bench_function("iterations_5", |b| { - b.iter(|| { - let result = naive_ppr( - black_box(&graph), - black_box(0.15_f64), - black_box(source_bias.clone()), - black_box(5), // 更少迭代次数 - ); - black_box(result); +/// 不同迭代次数下的Power Iteration性能 +/// 展示Power Iteration算法的时间复杂度与迭代次数的关系 +/// 执行时间与迭代次数呈近似线性增长 +fn bench_power_iteration_variants(c: &mut Criterion) { + let (graph, source_bias) = prepare_test_data(); + + let mut group = c.benchmark_group("power_iteration_variants"); + group.sample_size(10); + group.sampling_mode(SamplingMode::Flat); + + for iterations in [5, 10, 15, 20].iter() { + group.bench_function(format!("iterations_{}", iterations), |b| { + b.iter(|| { + let result = naive_ppr( + black_box(&graph), + black_box(OrdFloat::from(0.15)), + black_box(source_bias.clone()), + black_box(*iterations), + ); + black_box(result); + }); }); - }); + } - // Benchmark 5: 测试中等阻尼因子 - group.bench_function("damping_medium", |b| { - b.iter(|| { - let result = naive_ppr( - black_box(&graph), - black_box(0.5_f64), // 中等阻尼因子 - black_box(source_bias.clone()), - black_box(15), - ); - black_box(result); + group.finish(); +} + +/// 不同残差阈值下的Forward Push性能 +/// 展示阈值对Forward Push算法的精度和速度的影响 +/// 阈值越小,精度越高但执行时间也相应增加 +fn bench_forward_push_variants(c: &mut Criterion) { + let (graph, source_bias) = prepare_test_data(); + + let mut group = c.benchmark_group("forward_push_variants"); + group.sample_size(10); + group.sampling_mode(SamplingMode::Flat); + + for threshold in [0.001, 0.0001, 0.00001].iter() { + group.bench_function(format!("threshold_{}", threshold), |b| { + b.iter(|| { + let result = weighted_ppr_fp( + black_box(&graph), + black_box(OrdFloat::from(0.15)), + black_box(source_bias.clone()), + black_box(OrdFloat::from(*threshold)), + black_box(simple_weight_calc), + black_box(&()), + ); + black_box(result); + }); }); - }); + } group.finish(); } -/// 参数化benchmark组,测试不同参数组合 -fn bench_naive_ppr_parameterized(c: &mut Criterion) { +/// 不同阻尼因子下的性能对比 +/// 测试阻尼因子[0.1, 0.3, 0.5, 0.7]对两种算法的影响 +/// Forward Push在高阻尼因子下性能略有下降 +fn bench_damping_factors(c: &mut Criterion) { let (graph, source_bias) = prepare_test_data(); - // 设置采样次数为10 - let mut group = c.benchmark_group("naive_ppr_parametrization"); - 
group.sample_size(10); // 设置采样次数为10 + let mut group = c.benchmark_group("damping_factors"); + group.sample_size(10); group.sampling_mode(SamplingMode::Flat); - // 测试不同迭代次数 - for iterations in [5, 10, 15, 20, 50].iter() { - group.bench_function(format!("iterations_{}", iterations), |b| { + for damping in [0.1, 0.3, 0.5, 0.7].iter() { + group.bench_function(format!("power_iteration_damping_{}", damping), |b| { b.iter(|| { let result = naive_ppr( black_box(&graph), - black_box(0.15_f64), + black_box(OrdFloat::from(*damping)), black_box(source_bias.clone()), - black_box(*iterations), + black_box(15), + ); + black_box(result); + }); + }); + + group.bench_function(format!("forward_push_damping_{}", damping), |b| { + b.iter(|| { + let result = weighted_ppr_fp( + black_box(&graph), + black_box(OrdFloat::from(*damping)), + black_box(source_bias.clone()), + black_box(OrdFloat::from(0.0001)), + black_box(simple_weight_calc), + black_box(&()), ); black_box(result); }); }); } - // 测试不同阻尼因子 - for damping in [0.1, 0.3, 0.5, 0.7, 0.9].iter() { - group.bench_function(format!("damping_{}", damping), |b| { + group.finish(); +} + +/// 中等规模图性能测试 (100-500节点) +/// 测试规模:100, 300, 500个节点的图 +/// 在此规模下Forward Push的性能优势开始显著体现(快100-1000倍) +fn bench_medium_scale_performance(c: &mut Criterion) { + let mut group = c.benchmark_group("medium_scale_performance"); + group.sample_size(10); + group.sampling_mode(SamplingMode::Flat); + + let graph_sizes = [100, 300, 500]; + + for size in graph_sizes.iter() { + let (graph, source_bias) = prepare_large_test_data(*size); + + // Power Iteration + group.bench_function(format!("power_iteration_size_{}", size), |b| { b.iter(|| { let result = naive_ppr( black_box(&graph), - black_box(*damping as f64), + black_box(OrdFloat::from(0.15)), black_box(source_bias.clone()), - black_box(15), + black_box(10), // 中等规模图使用较少的迭代次数 + ); + black_box(result); + }); + }); + + // Forward Push + group.bench_function(format!("forward_push_size_{}", size), |b| { + b.iter(|| { + let result = weighted_ppr_fp( + black_box(&graph), + black_box(OrdFloat::from(0.15)), + black_box(source_bias.clone()), + black_box(OrdFloat::from(0.0001)), + black_box(simple_weight_calc), + black_box(&()), + ); + black_box(result); + }); + }); + } + + group.finish(); +} + +/// 大规模图性能测试 (1000-3000节点) +/// 主要测试Forward Push算法,Power Iteration在1000节点以上效率过低 +/// 结果显示Forward Push的执行时间与图规模近似线性增长 +fn bench_large_scale_performance(c: &mut Criterion) { + let mut group = c.benchmark_group("large_scale_performance"); + group.sample_size(10); + group.sampling_mode(SamplingMode::Flat); + + let graph_sizes = [1000, 2000, 3000]; + + for size in graph_sizes.iter() { + let (graph, source_bias) = prepare_large_test_data(*size); + + // 对于大规模图,只测试Forward Push(Power Iteration可能太慢) + group.bench_function(format!("forward_push_size_{}", size), |b| { + b.iter(|| { + let result = weighted_ppr_fp( + black_box(&graph), + black_box(OrdFloat::from(0.15)), + black_box(source_bias.clone()), + black_box(OrdFloat::from(0.0001)), + black_box(simple_weight_calc), + black_box(&()), ); black_box(result); }); @@ -156,7 +328,11 @@ fn bench_naive_ppr_parameterized(c: &mut Criterion) { criterion_group!( benches, - bench_naive_ppr_function, - bench_naive_ppr_parameterized + bench_basic_comparison, + bench_power_iteration_variants, + bench_forward_push_variants, + bench_damping_factors, + bench_medium_scale_performance, + bench_large_scale_performance ); criterion_main!(benches); diff --git a/src/utils/graph_algo/ppr.rs b/src/utils/graph_algo/ppr.rs index f793be2..276da6b 
100644 --- a/src/utils/graph_algo/ppr.rs +++ b/src/utils/graph_algo/ppr.rs @@ -342,7 +342,7 @@ where }; let source_node_count = D::from_usize(normalized_personalized_vec.len()); - println!("source_node_count: {:?}", source_node_count); + //println!("source_node_count: {:?}", source_node_count); //初始化残差和保留 let mut reserve_vec = vec![D::zero(); graph.node_bound()]; @@ -366,11 +366,11 @@ where //每次取残差最大的节点进行push,加速收敛 while let Some(residue_i) = residue_vec.iter().copied().max() { - println!("Processing node {}", residue_i.idx); + //println!("Processing node {}", residue_i.idx); let out_edges = graph.edges(graph.from_index(residue_i.idx)); //动态归一化的边权计算 if !ppr_edge_weight_cache.contains_key(&graph.from_index(residue_i.idx)) { - println!("Calculating edge weights for node {}", residue_i.idx); + //println!("Calculating edge weights for node {}", residue_i.idx); let weights = out_edges .map(|edge| { let weight = weight_calc(&edge, dynamic_query); @@ -394,7 +394,7 @@ where } let edge_weights = &ppr_edge_weight_cache[&graph.from_index(residue_i.idx)]; - println!("edge_weights: {:?}", edge_weights); + //println!("edge_weights: {:?}", edge_weights); //清空当前节点残差 residue_vec[residue_i.idx].value = D::zero(); @@ -640,11 +640,50 @@ mod test { ) } #[test] - //TODO: pass this test fn ppr_forward_push_toy_graph_init_b() { let (graph, true_ans, indexes) = toy_graph_with_init_b(); let mut source_bias = HashMap::new(); + source_bias.insert(indexes[1], OrdFloat::from_f64(1.0)); + + let ppr_ans = weighted_ppr_fp( + &graph, + OrdFloat::from_f64(0.15), + source_bias, + OrdFloat::from_f64(0.002), + |_, _| OrdFloat::from_f64(1.0), + &"1", + ); + let ans_sum: f64 = ppr_ans + .values() + .copied() + .sum::>() + .into_inner(); + assert!(ans_sum - 1.0 < 1e-5, "the sum is: {ans_sum}"); + + let avg_diff = 0.25 + * indexes + .iter() + .map(|idx| { + let actual: f64 = ppr_ans[idx].into_inner(); + let expected = true_ans[idx]; + diff(actual, expected) + }) + .sum::(); + + assert!( + avg_diff < 0.005, + "failed with avg_diff {}, whole ppr_vec is : {:?}, but it should be : {:?}", + avg_diff, + ppr_ans, + true_ans + ) + } + #[test] + fn ppr_forward_push_toy_graph_init_ab() { + let (graph, true_ans, indexes) = toy_graph_with_init_ab(); + let mut source_bias = HashMap::new(); source_bias.insert(indexes[0], OrdFloat::from_f64(1.0)); + source_bias.insert(indexes[1], OrdFloat::from_f64(1.0)); let ppr_ans = weighted_ppr_fp( &graph, @@ -659,7 +698,7 @@ mod test { .copied() .sum::>() .into_inner(); - assert!(ans_sum - 1.0 < f64::EPSILON); + assert!(ans_sum - 1.0 < 1e-5, "the sum is: {ans_sum}"); let avg_diff = 0.25 * indexes @@ -679,6 +718,7 @@ mod test { true_ans ) } + #[test] fn pressure_large_graph_test() { let (graph, nodes) = pressure_large_graph();