diff --git a/Cargo.lock b/Cargo.lock index 010a3a3..4088e5e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -91,6 +91,15 @@ dependencies = [ "equator", ] +[[package]] +name = "alloca" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a7d05ea6aea7e9e64d25b9156ba2fee3fdd659e34e41063cd2fc7cd020d7f4" +dependencies = [ + "cc", +] + [[package]] name = "allocator-api2" version = "0.2.21" @@ -125,6 +134,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstyle" version = "1.0.11" @@ -167,6 +182,12 @@ version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223" +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "arg_enum_proc_macro" version = "0.3.4" @@ -275,7 +296,7 @@ checksum = "fd45deb3dbe5da5cdb8d6a670a7736d735ba65b455328440f236dfb113727a3d" dependencies = [ "Inflector", "async-graphql-parser", - "darling", + "darling 0.20.11", "proc-macro-crate", "proc-macro2", "quote", @@ -725,6 +746,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "castaway" version = "0.2.3" @@ -896,6 +923,37 @@ dependencies = [ "libloading", ] +[[package]] +name = "clap" +version = "4.5.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9e340e012a1bf4935f5282ed1436d1489548e8f72308207ea5df0e23d2d03f8" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.5.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d76b5d13eaa18c901fd2f7fca939fefe3a0727a953561fefdf3b2922b8569d00" +dependencies = [ + "anstyle", + "clap_lex", +] + +[[package]] +name = "clap_lex" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" + +[[package]] +name = "cmsketch" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7ee2cfacbd29706479902b06d75ad8f1362900836aa32799eabc7e004bfd854" + [[package]] name = "color_quant" version = "1.1.0" @@ -971,6 +1029,17 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "core_affinity" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a034b3a7b624016c6e13f5df875747cc25f884156aad2abd12b6c46797971342" +dependencies = [ + "libc", + "num_cpus", + "winapi", +] + [[package]] name = "cpufeatures" version = "0.2.17" @@ -989,6 +1058,41 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d883447757bb0ee46f233e9dc22eb84d93a9508c9b868687b274fc431d886bf" +dependencies = [ + "alloca", + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "itertools 0.13.0", + "num-traits", + 
"oorandom", + "page_size", + "plotters", + "rayon", + "regex", + "serde", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed943f81ea2faa8dcecbbfa50164acf95d555afec96a27871663b300e387b2e4" +dependencies = [ + "cast", + "itertools 0.13.0", +] + [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -1053,14 +1157,38 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "darling" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" +dependencies = [ + "darling_core 0.14.4", + "darling_macro 0.14.4", +] + [[package]] name = "darling" version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ - "darling_core", - "darling_macro", + "darling_core 0.20.11", + "darling_macro 0.20.11", +] + +[[package]] +name = "darling_core" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.10.0", + "syn 1.0.109", ] [[package]] @@ -1073,17 +1201,28 @@ dependencies = [ "ident_case", "proc-macro2", "quote", - "strsim", + "strsim 0.11.1", "syn 2.0.104", ] +[[package]] +name = "darling_macro" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" +dependencies = [ + "darling_core 0.14.4", + "quote", + "syn 1.0.109", +] + [[package]] name = "darling_macro" version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ - "darling_core", + "darling_core 0.20.11", "quote", "syn 2.0.104", ] @@ -1151,7 +1290,7 @@ version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" dependencies = [ - "darling", + "darling 0.20.11", "proc-macro2", "quote", "syn 2.0.104", @@ -1259,6 +1398,12 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" +[[package]] +name = "downcast-rs" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2" + [[package]] name = "dtoa" version = "1.0.10" @@ -1417,6 +1562,16 @@ dependencies = [ "tempfile", ] +[[package]] +name = "fastant" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e825441bfb2d831c47c97d05821552db8832479f44c571b97fededbf0099c07" +dependencies = [ + "small_ctor", + "web-time", +] + [[package]] name = "fastembed" version = "5.2.0" @@ -1487,6 +1642,18 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8bf7cc16383c4b8d58b9905a8509f02926ce3058053c056376248d958c9df1e8" +[[package]] +name = "flume" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" +dependencies = [ + "futures-core", + 
"futures-sink", + "nanorand", + "spin", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1499,6 +1666,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + [[package]] name = "foreign-types" version = "0.3.2" @@ -1529,12 +1702,131 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d8866fac38f53fc87fa3ae1b09ddd723e0482f8fa74323518b4c59df2c55a00a" +[[package]] +name = "foyer" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a31f699ce88ac9a53677ca0b1f7a3a902bf3bfae0579e16e86ddf61dee569c0" +dependencies = [ + "anyhow", + "equivalent", + "foyer-common", + "foyer-memory", + "foyer-storage", + "futures-util", + "madsim-tokio", + "mixtrics", + "pin-project", + "serde", + "tokio", + "tracing", +] + +[[package]] +name = "foyer-common" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9ea2c266c9d93ea37c3960f2d0bb625981eefd38120eb06542804bf8169a187" +dependencies = [ + "anyhow", + "bincode", + "bytes", + "cfg-if", + "itertools 0.14.0", + "madsim-tokio", + "mixtrics", + "parking_lot 0.12.4", + "pin-project", + "serde", + "tokio", + "twox-hash", +] + +[[package]] +name = "foyer-intrusive-collections" +version = "0.10.0-dev" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e4fee46bea69e0596130e3210e65d3424e0ac1e6df3bde6636304bdf1ca4a3b" +dependencies = [ + "memoffset", +] + +[[package]] +name = "foyer-memory" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09941796e5f8301e82e81e0c9e7514a8443524a461f9a2177500c7525aa73723" +dependencies = [ + "anyhow", + "arc-swap", + "bitflags 2.9.1", + "cmsketch", + "equivalent", + "foyer-common", + "foyer-intrusive-collections", + "futures-util", + "hashbrown 0.16.1", + "itertools 0.14.0", + "madsim-tokio", + "mixtrics", + "parking_lot 0.12.4", + "paste", + "pin-project", + "serde", + "tokio", + "tracing", +] + +[[package]] +name = "foyer-storage" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75fc3db8b685c3eb8b13f05847436933f1f68e91b68510c615d90e9a69541c84" +dependencies = [ + "allocator-api2", + "anyhow", + "bytes", + "core_affinity", + "equivalent", + "fastant", + "flume", + "foyer-common", + "foyer-memory", + "fs4", + "futures-core", + "futures-util", + "hashbrown 0.16.1", + "io-uring", + "itertools 0.14.0", + "libc", + "lz4", + "madsim-tokio", + "parking_lot 0.12.4", + "pin-project", + "rand 0.9.1", + "serde", + "tokio", + "tracing", + "twox-hash", + "zstd", +] + [[package]] name = "fragile" version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28dd6caf6059519a65843af8fe2a3ae298b14b80179855aeb4adc2c1934ee619" +[[package]] +name = "fs4" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8640e34b88f7652208ce9e88b1a37a2ae95227d84abec377ccd3c5cfeb141ed4" +dependencies = [ + "rustix", + "windows-sys 0.59.0", +] + [[package]] name = "fst" version = "0.4.7" @@ -1929,7 +2221,18 @@ checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5" dependencies = [ "allocator-api2", 
"equivalent", - "foldhash", + "foldhash 0.1.5", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash 0.2.0", ] [[package]] @@ -2390,6 +2693,17 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "io-uring" +version = "0.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdd7bddefd0a8833b88a4b68f90dae22c7450d11b354198baee3874fd811b344" +dependencies = [ + "bitflags 2.9.1", + "cfg-if", + "libc", +] + [[package]] name = "ipnet" version = "2.11.0" @@ -2433,6 +2747,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.14.0" @@ -2672,6 +2995,15 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" +[[package]] +name = "lz4" +version = "1.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a20b523e860d03443e98350ceaac5e71c6ba89aea7d960769ec3ce37f4de5af4" +dependencies = [ + "lz4-sys", +] + [[package]] name = "lz4-sys" version = "1.11.1+lz4-1.10.0" @@ -2704,6 +3036,61 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30" +[[package]] +name = "madsim" +version = "0.2.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18351aac4194337d6ea9ffbd25b3d1540ecc0754142af1bff5ba7392d1f6f771" +dependencies = [ + "ahash 0.8.12", + "async-channel", + "async-stream", + "async-task", + "bincode", + "bytes", + "downcast-rs", + "errno", + "futures-util", + "lazy_static", + "libc", + "madsim-macros", + "naive-timer", + "panic-message", + "rand 0.8.5", + "rand_xoshiro", + "rustversion", + "serde", + "spin", + "tokio", + "tokio-util", + "toml 0.9.5", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "madsim-macros" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3d248e97b1a48826a12c3828d921e8548e714394bf17274dd0a93910dc946e1" +dependencies = [ + "darling 0.14.4", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "madsim-tokio" +version = "0.2.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d3eb2acc57c82d21d699119b859e2df70a91dbdb84734885a1e72be83bdecb5" +dependencies = [ + "madsim", + "spin", + "tokio", +] + [[package]] name = "maplit" version = "1.0.2" @@ -2774,6 +3161,15 @@ version = "2.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + [[package]] name = "miette" version = "5.10.0" @@ -2840,6 +3236,16 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "mixtrics" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"fb252c728b9d77c6ef9103f0c81524fa0a3d3b161d0a936295d7fbeff6e04c11" +dependencies = [ + "itertools 0.14.0", + "parking_lot 0.12.4", +] + [[package]] name = "mockall" version = "0.13.1" @@ -2904,6 +3310,12 @@ dependencies = [ "version_check", ] +[[package]] +name = "naive-timer" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "034a0ad7deebf0c2abcf2435950a6666c3c15ea9d8fad0c0f48efa8a7f843fed" + [[package]] name = "nalgebra" version = "0.34.0" @@ -2966,6 +3378,15 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "nanorand" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" +dependencies = [ + "getrandom 0.2.16", +] + [[package]] name = "native-tls" version = "0.2.14" @@ -3076,6 +3497,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys 0.60.2", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -3219,6 +3649,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + [[package]] name = "openssl" version = "0.10.73" @@ -3303,6 +3739,22 @@ dependencies = [ "ureq 3.1.2", ] +[[package]] +name = "page_size" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da" +dependencies = [ + "libc", + "winapi", +] + +[[package]] +name = "panic-message" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384e52fd8fbd4cbe3c317e8216260c21a0f9134de108cea8a4dd4e7e152c472d" + [[package]] name = "parking" version = "2.2.1" @@ -3558,6 +4010,34 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "png" version = "0.17.16" @@ -3940,6 +4420,15 @@ dependencies = [ "getrandom 0.3.3", ] +[[package]] +name = "rand_xoshiro" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa" +dependencies = [ + "rand_core 0.6.4", +] + [[package]] name = "rav1e" version = "0.7.1" @@ -4747,7 +5236,7 @@ version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"de90945e6565ce0d9a25098082ed4ee4002e047cb59892c318d66821e14bb30f" dependencies = [ - "darling", + "darling 0.20.11", "proc-macro2", "quote", "syn 2.0.104", @@ -4775,6 +5264,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shlex" version = "1.3.0" @@ -4848,6 +5346,12 @@ version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04dc19736151f35336d325007ac991178d504a119863a2fcb3758cdb5e52c50d" +[[package]] +name = "small_ctor" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88414a5ca1f85d82cc34471e975f0f74f6aa54c40f062efa42c0080e7f763f81" + [[package]] name = "smallvec" version = "1.15.1" @@ -4904,9 +5408,11 @@ dependencies = [ "approx 0.5.1", "async-trait", "chrono", + "criterion", "dotenvy", "fastembed", "formatx", + "foyer", "log", "mockall", "nalgebra", @@ -4946,6 +5452,9 @@ name = "spin" version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] [[package]] name = "spm_precompiled" @@ -5027,6 +5536,12 @@ dependencies = [ "quote", ] +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + [[package]] name = "strsim" version = "0.11.1" @@ -5177,7 +5692,7 @@ dependencies = [ "sha2", "snap", "storekey", - "strsim", + "strsim 0.11.1", "subtle", "sysinfo", "tempfile", @@ -5459,6 +5974,16 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.9.0" @@ -5798,6 +6323,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" +dependencies = [ + "nu-ansi-term", + "sharded-slab", + "smallvec 1.15.1", + "thread_local", + "tracing-core", + "tracing-log", ] [[package]] @@ -5838,6 +6389,15 @@ dependencies = [ "utf-8", ] +[[package]] +name = "twox-hash" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ea3136b675547379c4bd395ca6b938e5ad3c3d20fad76e7fe85f9e0d011419c" +dependencies = [ + "rand 0.9.1", +] + [[package]] name = "typenum" version = "1.18.0" @@ -6046,6 +6606,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "vart" version = "0.8.1" @@ -6766,6 +7332,24 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + [[package]] name = "zstd-sys" version = "2.0.15+zstd.1.5.7" diff --git a/Cargo.toml b/Cargo.toml index bc21d69..290e7e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ dotenvy = "0.15.7" tokio-util = "0.7.16" toml = "0.9.5" thiserror = "2.0.17" +foyer = { version = "0.21.1", features = ["serde"] } [profile.release] lto = true @@ -38,3 +39,10 @@ strip = true opt-level = 3 panic = 'abort' codegen-units = 1 + +[dev-dependencies] +criterion = "0.8.1" + +[[bench]] +name = "ppr" +harness = false diff --git a/benches/ppr.rs b/benches/ppr.rs new file mode 100644 index 0000000..ef8af1a --- /dev/null +++ b/benches/ppr.rs @@ -0,0 +1,338 @@ +//! PPR算法性能基准测试 +//! +//! 本文件包含两种PPR(Personalized PageRank)算法的性能对比测试: +//! 1. Power Iteration算法:传统迭代方法,收敛性好但较慢 +//! 2. Forward Push算法:近似算法,速度极快但有一定精度损失 +//! +//! 使用方法: +//! 1. 运行所有测试: cargo bench +//! 2. 运行特定测试: cargo bench <测试组名> +//! 3. 测试结果会生成在target/criterion目录下 + +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use petgraph::stable_graph::NodeIndex; +use petgraph::{prelude::StableDiGraph, visit::NodeIndexable}; +use soul_mem::utils::graph_algo::ord_float::OrdFloat; +use soul_mem::utils::graph_algo::ppr::{naive_ppr, weighted_ppr_fp}; +use std::collections::HashMap; +use std::hint::black_box; + +/// 创建小型测试图的辅助函数(在benchmark计时范围外使用) +/// 使用稀疏图结构以提高测试效率 +fn create_test_graph(size: usize) -> (StableDiGraph>, Vec>) { + let mut graph = StableDiGraph::new(); + let mut nodes = Vec::new(); + + // 创建指定大小的图 + for i in 0..size { + let node = graph.add_node(format!("node_{}", i)); + nodes.push(node); + + // 每个节点连接到前min(10, 节点数)个节点(如果存在) + nodes.iter().take(10.min(nodes.len())).for_each(|idx| { + graph.add_edge(node, *idx, OrdFloat::from(1.0)); + }); + + // 添加自环 + graph.add_edge(node, node, OrdFloat::from(1.0)); + } + (graph, nodes) +} + +/// 创建大规模测试图的辅助函数(在benchmark计时范围外使用) +fn create_large_test_graph( + size: usize, +) -> (StableDiGraph>, Vec>) { + let mut graph = StableDiGraph::new(); + let mut nodes = Vec::new(); + + // 创建指定大小的图(优化版本,减少边数以保持可管理性) + for i in 0..size { + let node = graph.add_node(format!("node_{}", i)); + nodes.push(node); + + // 对于大规模图,限制边的数量以提高性能 + let max_connections = if size > 1000 { 5 } else { 10 }; + nodes + .iter() + .take(max_connections.min(nodes.len())) + .for_each(|idx| { + graph.add_edge(node, *idx, OrdFloat::from(1.0)); + }); + + // 添加自环 + graph.add_edge(node, node, OrdFloat::from(1.0)); + } + (graph, nodes) +} + +/// 简单的边权重计算函数(用于forward push算法) +fn simple_weight_calc( + _edge: &petgraph::stable_graph::EdgeReference>, + _query: &(), +) -> OrdFloat { + OrdFloat::from(1.0) +} + +/// 准备小型测试数据(在benchmark外执行) +/// 用于基础性能对比测试 +fn prepare_test_data() -> ( + StableDiGraph>, + HashMap, OrdFloat>, +) { + let (graph, nodes) = create_test_graph(20); + let mut source_bias = HashMap::new(); + + // 前3个节点作为源节点 + nodes.iter().take(3).for_each(|idx| { + source_bias.insert(*idx, 
OrdFloat::from(graph.to_index(*idx) as f64)); + }); + + (graph, source_bias) +} + +/// 准备大规模测试数据(在benchmark外执行) +/// 专门用于中等和大型规模图的性能测试 +fn prepare_large_test_data( + size: usize, +) -> ( + StableDiGraph>, + HashMap, OrdFloat>, +) { + let (graph, nodes) = create_large_test_graph(size); + let mut source_bias = HashMap::new(); + + // 前min(5, size/100)个节点作为源节点 + let source_count = (5.max(size / 100)).min(nodes.len()); + nodes.iter().take(source_count).for_each(|idx| { + source_bias.insert(*idx, OrdFloat::from(graph.to_index(*idx) as f64)); + }); + + (graph, source_bias) +} + +/// 基础性能对比:Power Iteration vs Forward Push +/// 测试20个节点的小型图,展示两种算法的基本性能差异 +/// 结果显示Forward Push比Power Iteration快约20-30倍 +fn bench_basic_comparison(c: &mut Criterion) { + let (graph, source_bias) = prepare_test_data(); + + let mut group = c.benchmark_group("basic_comparison"); + group.sample_size(10); + group.sampling_mode(SamplingMode::Flat); + + // Power Iteration算法 + group.bench_function("power_iteration_15_iters", |b| { + b.iter(|| { + let result = naive_ppr( + black_box(&graph), + black_box(OrdFloat::from(0.15)), + black_box(source_bias.clone()), + black_box(15), + ); + black_box(result); + }); + }); + + // Forward Push算法 + group.bench_function("forward_push_threshold_1e-4", |b| { + b.iter(|| { + let result = weighted_ppr_fp( + black_box(&graph), + black_box(OrdFloat::from(0.15)), + black_box(source_bias.clone()), + black_box(OrdFloat::from(0.0001)), + black_box(simple_weight_calc), + black_box(&()), + ); + black_box(result); + }); + }); + + group.finish(); +} + +/// 不同迭代次数下的Power Iteration性能 +/// 展示Power Iteration算法的时间复杂度与迭代次数的关系 +/// 执行时间与迭代次数呈近似线性增长 +fn bench_power_iteration_variants(c: &mut Criterion) { + let (graph, source_bias) = prepare_test_data(); + + let mut group = c.benchmark_group("power_iteration_variants"); + group.sample_size(10); + group.sampling_mode(SamplingMode::Flat); + + for iterations in [5, 10, 15, 20].iter() { + group.bench_function(format!("iterations_{}", iterations), |b| { + b.iter(|| { + let result = naive_ppr( + black_box(&graph), + black_box(OrdFloat::from(0.15)), + black_box(source_bias.clone()), + black_box(*iterations), + ); + black_box(result); + }); + }); + } + + group.finish(); +} + +/// 不同残差阈值下的Forward Push性能 +/// 展示阈值对Forward Push算法的精度和速度的影响 +/// 阈值越小,精度越高但执行时间也相应增加 +fn bench_forward_push_variants(c: &mut Criterion) { + let (graph, source_bias) = prepare_test_data(); + + let mut group = c.benchmark_group("forward_push_variants"); + group.sample_size(10); + group.sampling_mode(SamplingMode::Flat); + + for threshold in [0.001, 0.0001, 0.00001].iter() { + group.bench_function(format!("threshold_{}", threshold), |b| { + b.iter(|| { + let result = weighted_ppr_fp( + black_box(&graph), + black_box(OrdFloat::from(0.15)), + black_box(source_bias.clone()), + black_box(OrdFloat::from(*threshold)), + black_box(simple_weight_calc), + black_box(&()), + ); + black_box(result); + }); + }); + } + + group.finish(); +} + +/// 不同阻尼因子下的性能对比 +/// 测试阻尼因子[0.1, 0.3, 0.5, 0.7]对两种算法的影响 +/// Forward Push在高阻尼因子下性能略有下降 +fn bench_damping_factors(c: &mut Criterion) { + let (graph, source_bias) = prepare_test_data(); + + let mut group = c.benchmark_group("damping_factors"); + group.sample_size(10); + group.sampling_mode(SamplingMode::Flat); + + for damping in [0.1, 0.3, 0.5, 0.7].iter() { + group.bench_function(format!("power_iteration_damping_{}", damping), |b| { + b.iter(|| { + let result = naive_ppr( + black_box(&graph), + black_box(OrdFloat::from(*damping)), + black_box(source_bias.clone()), + 
black_box(15), + ); + black_box(result); + }); + }); + + group.bench_function(format!("forward_push_damping_{}", damping), |b| { + b.iter(|| { + let result = weighted_ppr_fp( + black_box(&graph), + black_box(OrdFloat::from(*damping)), + black_box(source_bias.clone()), + black_box(OrdFloat::from(0.0001)), + black_box(simple_weight_calc), + black_box(&()), + ); + black_box(result); + }); + }); + } + + group.finish(); +} + +/// 中等规模图性能测试 (100-500节点) +/// 测试规模:100, 300, 500个节点的图 +/// 在此规模下Forward Push的性能优势开始显著体现(快100-1000倍) +fn bench_medium_scale_performance(c: &mut Criterion) { + let mut group = c.benchmark_group("medium_scale_performance"); + group.sample_size(10); + group.sampling_mode(SamplingMode::Flat); + + let graph_sizes = [100, 300, 500]; + + for size in graph_sizes.iter() { + let (graph, source_bias) = prepare_large_test_data(*size); + + // Power Iteration + group.bench_function(format!("power_iteration_size_{}", size), |b| { + b.iter(|| { + let result = naive_ppr( + black_box(&graph), + black_box(OrdFloat::from(0.15)), + black_box(source_bias.clone()), + black_box(10), // 中等规模图使用较少的迭代次数 + ); + black_box(result); + }); + }); + + // Forward Push + group.bench_function(format!("forward_push_size_{}", size), |b| { + b.iter(|| { + let result = weighted_ppr_fp( + black_box(&graph), + black_box(OrdFloat::from(0.15)), + black_box(source_bias.clone()), + black_box(OrdFloat::from(0.0001)), + black_box(simple_weight_calc), + black_box(&()), + ); + black_box(result); + }); + }); + } + + group.finish(); +} + +/// 大规模图性能测试 (1000-3000节点) +/// 主要测试Forward Push算法,Power Iteration在1000节点以上效率过低 +/// 结果显示Forward Push的执行时间与图规模近似线性增长 +fn bench_large_scale_performance(c: &mut Criterion) { + let mut group = c.benchmark_group("large_scale_performance"); + group.sample_size(10); + group.sampling_mode(SamplingMode::Flat); + + let graph_sizes = [1000, 2000, 3000]; + + for size in graph_sizes.iter() { + let (graph, source_bias) = prepare_large_test_data(*size); + + // 对于大规模图,只测试Forward Push(Power Iteration可能太慢) + group.bench_function(format!("forward_push_size_{}", size), |b| { + b.iter(|| { + let result = weighted_ppr_fp( + black_box(&graph), + black_box(OrdFloat::from(0.15)), + black_box(source_bias.clone()), + black_box(OrdFloat::from(0.0001)), + black_box(simple_weight_calc), + black_box(&()), + ); + black_box(result); + }); + }); + } + + group.finish(); +} + +criterion_group!( + benches, + bench_basic_comparison, + bench_power_iteration_variants, + bench_forward_push_variants, + bench_damping_factors, + bench_medium_scale_performance, + bench_large_scale_performance +); +criterion_main!(benches); diff --git a/src/cache.rs b/src/cache.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/lib.rs b/src/lib.rs index bf4bf15..ab818fb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ //! This lib crate is inspired from A-mem, an agentic memory system, and HippoRAG. 
pub mod memory; pub mod utils; +pub mod cache; \ No newline at end of file diff --git a/src/memory.rs b/src/memory.rs index 6df01bd..93484ef 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -1,4 +1,6 @@ +pub mod algo; pub mod embedding; pub mod memory_cluster; pub mod memory_links; pub mod memory_note; +pub mod working_memory; diff --git a/src/memory/algo.rs b/src/memory/algo.rs new file mode 100644 index 0000000..447f096 --- /dev/null +++ b/src/memory/algo.rs @@ -0,0 +1 @@ +pub mod retrieve; \ No newline at end of file diff --git a/src/memory/algo/retrieve.rs b/src/memory/algo/retrieve.rs new file mode 100644 index 0000000..1f84c30 --- /dev/null +++ b/src/memory/algo/retrieve.rs @@ -0,0 +1,13 @@ +use crate::memory::{embedding::MemoryEmbedding, memory_note::MemoryId}; +pub mod association; +pub mod cached_path; +pub mod deep_thought; +pub mod short_only; +pub mod similarity; + +pub trait RetrStrategy { + type Request: RetrRequest; //接受的查询参数类型 + fn retrieve(&self, request: Self::Request) -> Vec; //TODO:返回类型还没想好,暂定Vec,或许也可以考虑返回迭代器,看具体场景 +} + +pub trait RetrRequest {} diff --git a/src/memory/algo/retrieve/association.rs b/src/memory/algo/retrieve/association.rs new file mode 100644 index 0000000..78b773f --- /dev/null +++ b/src/memory/algo/retrieve/association.rs @@ -0,0 +1,24 @@ +use std::sync::Arc; + +use crate::memory::{ + algo::retrieve::RetrRequest, memory_note::MemoryId, working_memory::WorkingMemory, +}; + +use super::RetrStrategy; + +//用PPR变种算法进行联想 +pub struct RetrAssociation { + pub max_results: usize, +} +pub struct AssociationRequest { + working_mem: Arc, +} + +impl RetrRequest for AssociationRequest {} + +impl RetrStrategy for RetrAssociation { + type Request = AssociationRequest; + fn retrieve(&self, request: Self::Request) -> Vec { + todo!() + } +} diff --git a/src/memory/algo/retrieve/cached_path.rs b/src/memory/algo/retrieve/cached_path.rs new file mode 100644 index 0000000..98b60fa --- /dev/null +++ b/src/memory/algo/retrieve/cached_path.rs @@ -0,0 +1,24 @@ +//采用dfs,通过边权中的记忆向量来快速扩展子图信息,详见ReMindRAG + +use crate::memory::working_memory::WorkingMemory; +use std::sync::Arc; + +use super::RetrRequest; +use super::RetrStrategy; + +pub struct RetrCachedPath { + pub max_depth: usize, // dfs的最大深度 + pub expand_threshold: f64, //计算向量与查询向量的相似度大于此值,将被扩展 +} + +pub struct CachedPathRequest { + working_mem: Arc, //计算向量与查询向量的相似度大于此值,将被扩展 +} +impl RetrRequest for CachedPathRequest {} + +impl RetrStrategy for RetrCachedPath { + type Request = CachedPathRequest; + fn retrieve(&self, request: Self::Request) -> Vec { + todo!() + } +} diff --git a/src/memory/algo/retrieve/deep_thought.rs b/src/memory/algo/retrieve/deep_thought.rs new file mode 100644 index 0000000..626d6b2 --- /dev/null +++ b/src/memory/algo/retrieve/deep_thought.rs @@ -0,0 +1,23 @@ +use std::sync::Arc; + +use crate::memory::{ + algo::retrieve::RetrRequest, memory_note::MemoryId, working_memory::WorkingMemory, +}; + +use super::RetrStrategy; +// 采用 LLM进行的Plan-on-Graph +pub struct RetrDeepThought { + pub max_depth: usize, +} +pub struct DeepThoughtRequest { + working_mem: Arc, +} + +impl RetrRequest for DeepThoughtRequest {} + +impl RetrStrategy for RetrDeepThought { + type Request = DeepThoughtRequest; + fn retrieve(&self, request: Self::Request) -> Vec { + todo!() + } +} diff --git a/src/memory/algo/retrieve/short_only.rs b/src/memory/algo/retrieve/short_only.rs new file mode 100644 index 0000000..f224b3c --- /dev/null +++ b/src/memory/algo/retrieve/short_only.rs @@ -0,0 +1,23 @@ +use crate::memory::{ + 
algo::retrieve::RetrRequest, memory_note::MemoryId, working_memory::WorkingMemory, +}; +use std::sync::Arc; + +//仅提取短期记忆策略,即仅提取滑动窗口 +use super::RetrStrategy; +pub struct RetrShortOnly { + pub clipping_length: Option, + pub include_summary: bool, +} +pub struct ShortOnlyRequest { + working_mem: Arc, //因为检索算法很可能需要并发执行,使用Arc而非引用确保可以Send +} + +impl RetrRequest for ShortOnlyRequest {} + +impl RetrStrategy for RetrShortOnly { + type Request = ShortOnlyRequest; + fn retrieve(&self, request: Self::Request) -> Vec { + todo!() + } +} diff --git a/src/memory/algo/retrieve/similarity.rs b/src/memory/algo/retrieve/similarity.rs new file mode 100644 index 0000000..ae9f1a8 --- /dev/null +++ b/src/memory/algo/retrieve/similarity.rs @@ -0,0 +1,27 @@ +//仅提取相似记忆策略,即仅提取相似度大于阈值的记忆片段 +use super::RetrStrategy; +use crate::memory::{ + algo::retrieve::RetrRequest, memory_note::MemoryId, working_memory::WorkingMemory, +}; +use std::sync::Arc; +pub struct RetrSimilarity { + pub similarity_threshold: f64, + pub max_results: usize, +} +pub struct SimilarityRequest { + working_mem: Arc, +} +impl RetrRequest for SimilarityRequest {} +impl RetrStrategy for RetrSimilarity { + type Request = SimilarityRequest; + fn retrieve(&self, request: Self::Request) -> Vec { + //TODO: 减少计算量,以下只是一个最初步的实现 + let cos_similarities: Vec<(f64, MemoryId)> = vec![]; + cos_similarities + .into_iter() + .filter(|(similarity, _)| *similarity > self.similarity_threshold) + .map(|(_, id)| id) + .take(self.max_results) + .collect() + } +} diff --git a/src/memory/embedding.rs b/src/memory/embedding.rs index 096d2b0..a3c1a50 100644 --- a/src/memory/embedding.rs +++ b/src/memory/embedding.rs @@ -2,25 +2,86 @@ use thiserror::Error; use crate::memory::memory_note::EmbedMemoryNote; +pub struct EmbeddingVector { + // Placeholder for embedding vector +} +impl EmbeddingVector { + pub fn euclidean_distance(&self, other: &EmbeddingVector) -> Result { + todo!("Euclidean distance") + } + pub fn cosine_similarity(&self, other: &EmbeddingVector) -> Result { + todo!("Cosine similarity") + } + pub fn manhattan_distance(&self, other: &EmbeddingVector) -> Result { + todo!("Manhattan distance") + } +} + pub trait Embeddable { - fn embed(&self, model: &EmbeddingModel) -> Result; - fn embed_vec(&self, model: &EmbeddingModel) -> Result; + fn embed(&self, model: &EmbeddingModel) -> Result; + fn embed_vec(&self, model: &EmbeddingModel) -> Result; } +type EmbedCalcResult = Result; +type EmbeddingGenResult = Result; -//Only a placeholder for now #[derive(Debug, Error)] -pub enum EmbeddingError { - #[error("Invalid input")] +pub enum EmbeddingGenError { + #[error("Invalid input")] //缺失了某些必要字段 InvalidInput, #[error("Embedding failed")] EmbeddingFailed, } +//Only a placeholder for now +#[derive(Debug, Error)] +pub enum EmbeddingCalcError { + #[error("Invalid vec")] //缺失了某些必要字段 + InvalidVec, + #[error("Shape mismatch")] //维度不匹配 + ShapeMismatch, + #[error("Invalid number value")] //数值无效,例如NaN,Inf等 + InvalidNumValue, +} + pub struct EmbeddingModel { // Placeholder for embedding model wrapper } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct MemoryEmbedding { //Placeholder for embedding holder } +impl MemoryEmbedding { + pub fn euclidean_distance( + &self, + other: &MemoryEmbedding, + hyperparams: VecBlendHyperParams, + ) -> Result { + todo!("Euclidean distance") + } + pub fn cosine_similarity( + &self, + other: &MemoryEmbedding, + hyperparams: VecBlendHyperParams, + ) -> Result { + todo!("Cosine similarity") + } + pub fn 
manhattan_distance( + &self, + other: &MemoryEmbedding, + hyperparams: VecBlendHyperParams, + ) -> Result { + todo!("Manhattan distance") + } +} +#[derive(Debug, Clone, Copy)] +pub struct VecBlendHyperParams { + // Placeholder for vector blending hyperparameters +} +impl Default for VecBlendHyperParams { + fn default() -> Self { + VecBlendHyperParams { + // Placeholder for default values + } + } +} diff --git a/src/memory/memory_note.rs b/src/memory/memory_note.rs index f1bc202..87deaa6 100644 --- a/src/memory/memory_note.rs +++ b/src/memory/memory_note.rs @@ -10,7 +10,8 @@ mod proc_mem; use crate::memory::embedding::Embeddable; use crate::memory::embedding::MemoryEmbedding; -use super::embedding::EmbeddingError; +use super::embedding::EmbeddingCalcError; +use super::embedding::EmbeddingGenError; use super::embedding::EmbeddingModel; use super::memory_links::MemoryLink; @@ -66,10 +67,13 @@ impl MemoryNote { } } impl Embeddable for MemoryNote { - fn embed(&self, embedding_model: &EmbeddingModel) -> Result { + fn embed( + &self, + embedding_model: &EmbeddingModel, + ) -> Result { todo!("Add the embedding logic") } - fn embed_vec(&self, model: &EmbeddingModel) -> Result { + fn embed_vec(&self, model: &EmbeddingModel) -> Result { todo!("Add the embedding logic") } } diff --git a/src/memory/working_memory.rs b/src/memory/working_memory.rs new file mode 100644 index 0000000..30fb4ef --- /dev/null +++ b/src/memory/working_memory.rs @@ -0,0 +1,7 @@ +use crate::memory::memory_cluster::MemoryCluster; + +//代表工作记忆,应当包含记忆子图,短期记忆(滑动窗口),记忆的提取记录等。 +// 占位,后续逐渐增加内容 +pub struct WorkingMemory { + cluster: MemoryCluster, +} diff --git a/src/utils.rs b/src/utils.rs index 717e162..853c8e3 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1 +1,2 @@ -pub mod pipe; \ No newline at end of file +pub mod graph_algo; +pub mod pipe; diff --git a/src/utils/graph_algo.rs b/src/utils/graph_algo.rs new file mode 100644 index 0000000..b696c0f --- /dev/null +++ b/src/utils/graph_algo.rs @@ -0,0 +1,2 @@ +pub mod ord_float; +pub mod ppr; diff --git a/src/utils/graph_algo/ord_float.rs b/src/utils/graph_algo/ord_float.rs new file mode 100644 index 0000000..4807a93 --- /dev/null +++ b/src/utils/graph_algo/ord_float.rs @@ -0,0 +1,172 @@ +use core::iter::Sum; +use ordered_float::{FloatCore, OrderedFloat, PrimitiveFloat}; +use petgraph::algo::UnitMeasure; +use std::fmt::Debug; + +#[derive(Debug, Clone, Copy, PartialEq, Hash)] +pub struct OrdFloat(OrderedFloat); +impl OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat, +{ + pub fn into_inner(self) -> F { + self.0.into_inner() + } +} +impl Default for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + Default, +{ + fn default() -> Self { + OrdFloat(OrderedFloat::default()) + } +} +impl Sum for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + Sum, +{ + fn sum>(iter: I) -> Self { + OrdFloat(OrderedFloat(iter.map(|f| f.0.0).sum())) + } +} +impl std::ops::Add for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + std::ops::Add, +{ + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + OrdFloat(self.0 + rhs.0) + } +} +impl std::ops::AddAssign for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + std::ops::AddAssign, +{ + fn add_assign(&mut self, rhs: Self) { + self.0 += rhs.0; + } +} +impl std::ops::Sub for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + std::ops::Sub, +{ + type Output = Self; + fn sub(self, rhs: Self) -> Self::Output { + OrdFloat(self.0 - rhs.0) + } +} +impl 
std::ops::SubAssign for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + std::ops::SubAssign, +{ + fn sub_assign(&mut self, rhs: Self) { + self.0 -= rhs.0; + } +} +impl std::ops::Mul for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + std::ops::Mul, +{ + type Output = Self; + fn mul(self, rhs: Self) -> Self::Output { + OrdFloat(self.0 * rhs.0) + } +} +impl std::ops::MulAssign for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + std::ops::MulAssign, +{ + fn mul_assign(&mut self, rhs: Self) { + self.0 *= rhs.0; + } +} +impl std::ops::Div for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + std::ops::Div, +{ + type Output = Self; + fn div(self, rhs: Self) -> Self::Output { + OrdFloat(self.0 / rhs.0) + } +} +impl std::ops::DivAssign for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + std::ops::DivAssign, +{ + fn div_assign(&mut self, rhs: Self) { + self.0 /= rhs.0; + } +} +impl std::ops::Neg for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + std::ops::Neg, +{ + type Output = Self; + fn neg(self) -> Self::Output { + OrdFloat(-self.0) + } +} +impl UnitMeasure for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat + Debug + Sum + Default, +{ + fn default_tol() -> Self { + OrdFloat(OrderedFloat((F::epsilon()))) + } + ///如果F不能从f32转化,则变为NaN而非panic + fn from_f32(val: f32) -> Self { + OrdFloat(OrderedFloat(F::from(val).unwrap_or(F::nan()))) + } + ///如果F不能从f64转化,则变为NaN而非panic + fn from_f64(val: f64) -> Self { + OrdFloat(OrderedFloat(F::from(val).unwrap_or(F::nan()))) + } + ///如果F不能从usize转化,则变为NaN而非panic + fn from_usize(nb: usize) -> Self { + OrdFloat(OrderedFloat(F::from(nb).unwrap_or(F::nan()))) + } + fn one() -> Self { + OrdFloat(OrderedFloat(F::one())) + } + fn zero() -> Self { + OrdFloat(OrderedFloat(F::zero())) + } +} +impl From for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat, +{ + fn from(val: F) -> Self { + OrdFloat(OrderedFloat(val)) + } +} +impl Eq for OrdFloat where F: ordered_float::FloatCore + PrimitiveFloat {} + +impl PartialOrd for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat, +{ + fn ge(&self, other: &Self) -> bool { + self.0.ge(&other.0) + } + fn gt(&self, other: &Self) -> bool { + self.0.gt(&other.0) + } + fn le(&self, other: &Self) -> bool { + self.0.le(&other.0) + } + fn lt(&self, other: &Self) -> bool { + self.0.lt(&other.0) + } + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} +impl Ord for OrdFloat +where + F: ordered_float::FloatCore + PrimitiveFloat, +{ + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0.cmp(&other.0) + } +} diff --git a/src/utils/graph_algo/ppr.rs b/src/utils/graph_algo/ppr.rs new file mode 100644 index 0000000..276da6b --- /dev/null +++ b/src/utils/graph_algo/ppr.rs @@ -0,0 +1,734 @@ +use std::{ + cmp::Ordering, + collections::{BinaryHeap, HashMap}, + fmt::Debug, + hash::Hash, + ops::AddAssign, +}; + +use super::ord_float::OrdFloat; +use petgraph::{ + Direction::Incoming, + algo::UnitMeasure, + visit::{EdgeRef, IntoEdges, IntoNodeIdentifiers, NodeCount, NodeIndexable}, +}; + +///PPR: ppr_s = dampling_factor * P * ppr_s + (1-damping_factor) * personalized_vec, P为转移矩阵 +/// 对无出度的节点,采取与source_bias中的节点建立连接 +/// 必须保证source_bias的key是有效的NodeId, 否则会得到不正确的结果 +// 由于NodeId会由MemoryCluster提供,这不会造成额外的检查负担 +#[track_caller] +pub fn naive_ppr( + graph: G, + damping_factor: D, + personalized_vec: HashMap, + nb_iter: usize, +) -> HashMap +where + G: NodeCount + IntoEdges + 
NodeIndexable + IntoNodeIdentifiers, + D: UnitMeasure + Copy, + G::NodeId: Hash + Eq, +{ + let node_count = graph.node_count(); + if node_count == 0 { + return HashMap::new(); + } + + //Check the damping factor + assert!( + D::zero() <= damping_factor && damping_factor <= D::one(), + "Damping factor should be between 0 and 1." + ); + + //Check that the personalization distribution is a valid probability distribution + let personalized_sum: D = personalized_vec.values().copied().sum(); + assert!( + personalized_sum > D::zero(), + "Personalized Source bias sum must be positive" + ); + + //Normalize the personalization vector (the initial vector) + let normalized_personalized_vec: HashMap<G::NodeId, D> = if personalized_sum != D::one() { + personalized_vec + .into_iter() + .map(|(node_id, bias)| (node_id, bias / personalized_sum)) + .collect() + } else { + personalized_vec + }; + + //Valid node indices of the graph, to support StableGraph (indices may be non-contiguous) + let valid_index = graph + .node_identifiers() + .map(|node_id| graph.to_index(node_id)) + .collect::<Vec<_>>(); + + //Storage for the PPR values + //This may waste some memory (slots held by invalid indices); since working-memory subgraphs are not released and reloaded very often, the overhead should be acceptable + let mut ppr_ranks = vec![D::zero(); graph.node_bound()]; + let mut out_degrees = vec![D::zero(); graph.node_bound()]; + + //Initialize the PPR vector from the personalization vector; since the source nodes come from a vector-similarity top-k (k is usually small), this initialization usually speeds up convergence + normalized_personalized_vec + .iter() + .for_each(|(&node_id, &bias)| { + ppr_ranks[graph.to_index(node_id)] = bias; //SAFEUNWRAP: memory sized to the index bound was pre-allocated, so this access cannot go out of bounds. + }); + let normalized_bias_len = normalized_personalized_vec.len(); + //println!("normalized_bias: {:?}", normalized_bias); + + //Precompute the out-degree of every node + graph.node_identifiers().for_each(|node_id| { + out_degrees[graph.to_index(node_id)] = graph.edges(node_id).map(|_| D::one()).sum(); + }); + //println!("out_degrees: {:?}", out_degrees); + + for i in 0..nb_iter { + let ppr_vec_i = valid_index + .iter() + .map(|&computing_idx| { + let iter_ppr = valid_index + .iter() + .map(|&idx| { + //Find the out-edges of each node + let mut out_edges = graph.edges(graph.from_index(idx)); + + //Random-walk part; a node with no out-degree is treated as linking to every node with a non-zero personalization weight + if out_edges.any(|e| e.target() == graph.from_index(computing_idx)) { + damping_factor * ppr_ranks[idx] / out_degrees[idx] + } else if out_degrees[idx] == D::zero() { + normalized_personalized_vec + .get(&graph.from_index(computing_idx)) + .map(|_| { + damping_factor * ppr_ranks[idx] + / D::from_usize(normalized_bias_len) + }) + .unwrap_or(D::zero()) + } else { + D::zero() + } + }) + .sum::<D>(); + + //Random-restart (teleport) part + let random_back_part = if let Some(per_i) = + normalized_personalized_vec.get(&graph.from_index(computing_idx)) + { + (D::one() - damping_factor) * *per_i + } else { + D::zero() + }; + + (computing_idx, iter_ppr + random_back_part) + }) + .collect::<Vec<_>>(); + + // Normalize the PPR values so they sum to 1, for numerical stability + let sum = ppr_vec_i.iter().map(|(_, ppr)| *ppr).sum::<D>(); + + ppr_vec_i.iter().for_each(|&(idx, ppr)| { + ppr_ranks[idx] = ppr / sum; + }); + //println!("iteration {i}: PPR values: {:?}", ppr_ranks); + } + + //Final normalization + let sum = ppr_ranks.iter().map(|ppr| *ppr).sum::<D>(); + + //Return the PPR vector as a HashMap + graph + .node_identifiers() + .map(|node_id| (node_id, ppr_ranks[graph.to_index(node_id)] / sum)) + .collect() +} + +//Residue unit used by forward push +#[derive(Debug, Clone, Copy)] +struct ResidueUnit<Index, D_R> { + pub idx: Index, + pub value: D_R, +} +impl<Index, D_R> PartialOrd for ResidueUnit<Index, D_R> +where + Index: Copy, + D_R: UnitMeasure + Copy + PartialOrd, +{ + fn ge(&self, other: &Self) -> bool { + self.value.ge(&other.value) + } + fn gt(&self, other: &Self) -> bool { + self.value.gt(&other.value) + } + fn le(&self, other: &Self) -> bool { + self.value.le(&other.value) + } + fn lt(&self, other: &Self) -> bool { + self.value.lt(&other.value) + } + fn partial_cmp(&self, other: &Self) -> 
Option { + self.value.partial_cmp(&other.value) + } +} +impl PartialEq for ResidueUnit +where + Index: Copy, + D_R: UnitMeasure + Copy + PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.value.eq(&other.value) + } + fn ne(&self, other: &Self) -> bool { + self.value.ne(&other.value) + } +} +impl Eq for ResidueUnit +where + Index: Copy, + D_R: UnitMeasure + Copy + Eq, +{ +} +impl Ord for ResidueUnit +where + Index: Copy, + D_R: UnitMeasure + Copy + Ord + PartialOrd + Eq, +{ + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.value.cmp(&other.value) + } + fn clamp(self, min: Self, max: Self) -> Self + where + Self: Sized, + { + Self { + idx: self.idx, + value: self.value.clamp(min.value, max.value), + } + } + fn max(self, other: Self) -> Self + where + Self: Sized, + { + match self.cmp(&other) { + Ordering::Less => other, + Ordering::Equal => other, + Ordering::Greater => self, + } + } + fn min(self, other: Self) -> Self + where + Self: Sized, + { + match self.cmp(&other) { + Ordering::Less => self, + Ordering::Equal => self, + Ordering::Greater => other, + } + } +} + +//边权的单元表示 +#[derive(Debug)] +pub struct EdgeWeightUnit +where + D: UnitMeasure + Copy, +{ + pub target_node: NodeIdx, + pub idx: EdgeIdx, + pub value: D, +} +impl PartialOrd for EdgeWeightUnit +where + D: UnitMeasure + Copy + PartialOrd, +{ + fn ge(&self, other: &Self) -> bool { + self.value.ge(&other.value) + } + fn gt(&self, other: &Self) -> bool { + self.value.gt(&other.value) + } + fn le(&self, other: &Self) -> bool { + self.value.le(&other.value) + } + fn lt(&self, other: &Self) -> bool { + self.value.lt(&other.value) + } + fn partial_cmp(&self, other: &Self) -> Option { + self.value.partial_cmp(&other.value) + } +} +impl PartialEq for EdgeWeightUnit +where + D: UnitMeasure + Copy + PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.value.eq(&other.value) + } + fn ne(&self, other: &Self) -> bool { + self.value.ne(&other.value) + } +} +impl Eq for EdgeWeightUnit where D: UnitMeasure + Copy + Eq +{} + +impl Ord for EdgeWeightUnit +where + D: UnitMeasure + Copy + Ord + PartialOrd + Eq, +{ + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.value.cmp(&other.value) + } + fn clamp(self, min: Self, max: Self) -> Self + where + Self: Sized, + { + Self { + target_node: self.target_node, + idx: self.idx, + value: self.value.clamp(min.value, max.value), + } + } + fn max(self, other: Self) -> Self + where + Self: Sized, + { + match self.cmp(&other) { + Ordering::Less => other, + Ordering::Equal => other, + Ordering::Greater => self, + } + } + fn min(self, other: Self) -> Self + where + Self: Sized, + { + match self.cmp(&other) { + Ordering::Less => self, + Ordering::Equal => self, + Ordering::Greater => other, + } + } +} + +#[track_caller] +//TODO: make damping factor specific to each node +pub fn weighted_ppr_fp( + graph: G, + damping_factor: D, + personalized_vec: HashMap, + residue_threshold: D, + weight_calc: impl Fn(&G::EdgeRef, &Q) -> D, + dynamic_query: &Q, +) -> HashMap +where + G: NodeCount + IntoEdges + NodeIndexable + IntoNodeIdentifiers, + D: UnitMeasure + Copy + AddAssign + Ord, + G::NodeId: Hash + Eq + Debug, //TODO: delete the Debug Trait bound + G::EdgeId: Hash + Eq + Debug, +{ + //归一化个性化向量 + let personalized_sum = personalized_vec.values().copied().sum::(); + assert!( + personalized_sum > D::zero(), + "Personalized Source bias sum must be positive" + ); + + let normalized_personalized_vec: HashMap = if personalized_sum != D::one() { + personalized_vec + .into_iter() + 
.map(|(node_id, bias)| (node_id, bias / personalized_sum)) + .collect() + } else { + personalized_vec + }; + + let source_node_count = D::from_usize(normalized_personalized_vec.len()); + //println!("source_node_count: {:?}", source_node_count); + + //初始化残差和保留 + let mut reserve_vec = vec![D::zero(); graph.node_bound()]; + let mut residue_vec = (0..graph.node_bound()) + .map(|i| { + let residue_i = normalized_personalized_vec + .get(&graph.from_index(i)) + .copied() + .unwrap_or(D::zero()); + ResidueUnit { + idx: i, + value: residue_i, + } + }) + .collect::>(); + + let mut ppr_edge_weight_cache: HashMap< + G::NodeId, + Vec>, + > = HashMap::with_capacity(graph.node_count()); + + //每次取残差最大的节点进行push,加速收敛 + while let Some(residue_i) = residue_vec.iter().copied().max() { + //println!("Processing node {}", residue_i.idx); + let out_edges = graph.edges(graph.from_index(residue_i.idx)); + //动态归一化的边权计算 + if !ppr_edge_weight_cache.contains_key(&graph.from_index(residue_i.idx)) { + //println!("Calculating edge weights for node {}", residue_i.idx); + let weights = out_edges + .map(|edge| { + let weight = weight_calc(&edge, dynamic_query); + EdgeWeightUnit { + target_node: edge.target(), + idx: edge.id(), + value: weight, + } + }) + .collect::>(); + let sum = weights.iter().map(|v| v.value).sum::(); + let weights = weights + .into_iter() + .map(|w| EdgeWeightUnit { + target_node: w.target_node, + idx: w.idx, + value: w.value / sum, + }) + .collect::>(); + ppr_edge_weight_cache.insert(graph.from_index(residue_i.idx), weights); + } + + let edge_weights = &ppr_edge_weight_cache[&graph.from_index(residue_i.idx)]; + //println!("edge_weights: {:?}", edge_weights); + //清空当前节点残差 + residue_vec[residue_i.idx].value = D::zero(); + + //将部分残差转为保留 + reserve_vec[residue_i.idx] += (D::one() - damping_factor) * residue_i.value; + + //残差push + if let Some(edge_weight_max) = edge_weights.iter().max() { + //节点出度不为0的情况 + if residue_i.value * edge_weight_max.value > residue_threshold { + edge_weights.iter().for_each(|edge_w| { + residue_vec[graph.to_index(edge_w.target_node)].value += + damping_factor * edge_w.value * residue_i.value; + }); + } else { + break; + } + } else { + //节点出度为0的情况 + if residue_i.value / source_node_count > residue_threshold { + normalized_personalized_vec.keys().for_each(|node| { + residue_vec[graph.to_index(*node)].value += + damping_factor * residue_i.value / source_node_count; + }); + } else { + break; + } + } + } + let sum = reserve_vec.iter().copied().sum::(); + + graph + .node_identifiers() + .map(|node| { + let ppr_value = reserve_vec[graph.to_index(node)] / sum; + (node, ppr_value) + }) + .collect() +} + +#[cfg(test)] +mod test { + + use mockall::predicate::float; + use petgraph::{matrix_graph::NodeIndex, prelude::StableDiGraph}; + + use super::*; + fn diff(actual: f64, expected: f64) -> f64 { + if expected.abs() < f64::EPSILON && actual.abs() < f64::EPSILON { + 0.0 + } else { + let diff = (actual - expected).abs(); + diff + } + } + fn pressure_large_graph() -> (StableDiGraph, Vec>) { + let mut graph = StableDiGraph::new(); + let mut nodes = Vec::new(); + for i in 0..500 { + let mut node = graph.add_node("".to_string()); + if i % 2 == 0 || i % 7 == 0 { + graph.remove_node(node); + node = graph.add_node("".to_string()); + } + nodes.push(node); + graph.add_edge(node, node, 1.0); + nodes.iter().for_each(|idx| { + graph.add_edge(node, *idx, 1.0); + }); + } + (graph, nodes) + } + + fn test_toy_graph() -> (StableDiGraph, Vec>) { + let mut graph = StableDiGraph::new(); + let a = 
graph.add_node("A".to_string()); + let b = graph.add_node("B".to_string()); + //制造索引空洞 + graph.remove_node(b); + let b = graph.add_node("B".to_string()); + let c = graph.add_node("C".to_string()); + let d = graph.add_node("D".to_string()); + + graph.add_edge(a, b, 1.0); + graph.add_edge(a, c, 1.0); + graph.add_edge(b, c, 1.0); + graph.add_edge(c, d, 1.0); + + (graph, vec![a, b, c, d]) + } + fn toy_graph_with_init_a() -> ( + StableDiGraph, + HashMap, f64>, + Vec>, + ) { + let (graph, indexes) = test_toy_graph(); + let ans_vec: Vec = vec![0.851652742, 0.06387396045, 0.07345504972, 0.01101824785]; + let ans = indexes.iter().copied().zip(ans_vec).collect(); + (graph, ans, indexes) + } + fn toy_graph_with_init_b() -> ( + StableDiGraph, + HashMap, f64>, + Vec>, + ) { + let (graph, indexes) = test_toy_graph(); + let ans_vec: Vec = vec![0.0, 0.852878432, 0.1279320211, 0.01918954688]; + let ans = indexes.iter().copied().zip(ans_vec).collect(); + (graph, ans, indexes) + } + fn toy_graph_with_init_ab() -> ( + StableDiGraph, + HashMap, f64>, + Vec>, + ) { + let (graph, indexes) = test_toy_graph(); + let ans_vec: Vec = vec![0.4261326137, 0.4580925718, 0.1006738318, 0.00510098267]; + let ans = indexes.iter().copied().zip(ans_vec).collect(); + (graph, ans, indexes) + } + #[test] + fn ppr_toy_graph_init_a() { + let (graph, true_ans, indexes) = toy_graph_with_init_a(); + let mut source_bias = HashMap::new(); + source_bias.insert(indexes[0], 1.0); + + let ppr_ans = naive_ppr(&graph, 0.15_f64, source_bias, 15); + let ans_sum = ppr_ans.values().copied().sum::(); + assert!(ans_sum - 1.0 < f64::EPSILON); + + let avg_diff = 0.25 + * indexes + .iter() + .map(|idx| { + let actual = ppr_ans[idx]; + let expected = true_ans[idx]; + diff(actual, expected) + }) + .sum::(); + + assert!( + avg_diff < 0.005, + "failed with avg_diff {}, whole ppr_vec is : {:?}, but it should be : {:?}", + avg_diff, + ppr_ans, + true_ans + ) + } + #[test] + fn ppr_toy_graph_init_b() { + let (graph, true_ans, indexes) = toy_graph_with_init_b(); + let mut source_bias = HashMap::new(); + source_bias.insert(indexes[1], 1.0); + + let ppr_ans = naive_ppr(&graph, 0.15_f64, source_bias, 15); + let ans_sum = ppr_ans.values().copied().sum::(); + assert!(ans_sum - 1.0 < f64::EPSILON); + + let avg_diff = 0.25 + * indexes + .iter() + .map(|idx| { + let actual = ppr_ans[idx]; + let expected = true_ans[idx]; + diff(actual, expected) + }) + .sum::(); + + assert!( + avg_diff < 0.005, + "failed with avg_diff {}, whole ppr_vec is : {:?}, but it should be : {:?}", + avg_diff, + ppr_ans, + true_ans + ) + } + #[test] + fn ppr_toy_graph_init_ab() { + let (graph, true_ans, indexes) = toy_graph_with_init_ab(); + let mut source_bias = HashMap::new(); + source_bias.insert(indexes[0], 1.0); + source_bias.insert(indexes[1], 1.0); + + let ppr_ans = naive_ppr(&graph, 0.15_f64, source_bias, 15); + let ans_sum = ppr_ans.values().copied().sum::(); + assert!(ans_sum - 1.0 < f64::EPSILON); + + let avg_diff = 0.25 + * indexes + .iter() + .map(|idx| { + let actual = ppr_ans[idx]; + let expected = true_ans[idx]; + diff(actual, expected) + }) + .sum::(); + + assert!( + avg_diff < 0.005, + "failed with avg_diff {}, whole ppr_vec is : {:?}, but it should be : {:?}", + avg_diff, + ppr_ans, + true_ans + ) + } + #[test] + fn ppr_forward_push_toy_graph_init_a() { + let (graph, true_ans, indexes) = toy_graph_with_init_a(); + let mut source_bias = HashMap::new(); + source_bias.insert(indexes[0], OrdFloat::from_f64(1.0)); + + let ppr_ans = weighted_ppr_fp( + &graph, + 
OrdFloat::from_f64(0.15), + source_bias, + OrdFloat::from_f64(0.002), + |_, _| OrdFloat::from_f64(1.0), + &"1", + ); + let ans_sum: f64 = ppr_ans + .values() + .copied() + .sum::>() + .into_inner(); + assert!(ans_sum - 1.0 < f64::EPSILON); + + let avg_diff = 0.25 + * indexes + .iter() + .map(|idx| { + let actual: f64 = ppr_ans[idx].into_inner(); + let expected = true_ans[idx]; + diff(actual, expected) + }) + .sum::(); + + assert!( + avg_diff < 0.005, + "failed with avg_diff {}, whole ppr_vec is : {:?}, but it should be : {:?}", + avg_diff, + ppr_ans, + true_ans + ) + } + #[test] + fn ppr_forward_push_toy_graph_init_b() { + let (graph, true_ans, indexes) = toy_graph_with_init_b(); + let mut source_bias = HashMap::new(); + source_bias.insert(indexes[1], OrdFloat::from_f64(1.0)); + + let ppr_ans = weighted_ppr_fp( + &graph, + OrdFloat::from_f64(0.15), + source_bias, + OrdFloat::from_f64(0.002), + |_, _| OrdFloat::from_f64(1.0), + &"1", + ); + let ans_sum: f64 = ppr_ans + .values() + .copied() + .sum::>() + .into_inner(); + assert!(ans_sum - 1.0 < 1e-5, "the sum is: {ans_sum}"); + + let avg_diff = 0.25 + * indexes + .iter() + .map(|idx| { + let actual: f64 = ppr_ans[idx].into_inner(); + let expected = true_ans[idx]; + diff(actual, expected) + }) + .sum::(); + + assert!( + avg_diff < 0.005, + "failed with avg_diff {}, whole ppr_vec is : {:?}, but it should be : {:?}", + avg_diff, + ppr_ans, + true_ans + ) + } + #[test] + fn ppr_forward_push_toy_graph_init_ab() { + let (graph, true_ans, indexes) = toy_graph_with_init_ab(); + let mut source_bias = HashMap::new(); + source_bias.insert(indexes[0], OrdFloat::from_f64(1.0)); + source_bias.insert(indexes[1], OrdFloat::from_f64(1.0)); + + let ppr_ans = weighted_ppr_fp( + &graph, + OrdFloat::from_f64(0.15), + source_bias, + OrdFloat::from_f64(0.002), + |_, _| OrdFloat::from_f64(1.0), + &"1", + ); + let ans_sum: f64 = ppr_ans + .values() + .copied() + .sum::>() + .into_inner(); + assert!(ans_sum - 1.0 < 1e-5, "the sum is: {ans_sum}"); + + let avg_diff = 0.25 + * indexes + .iter() + .map(|idx| { + let actual: f64 = ppr_ans[idx].into_inner(); + let expected = true_ans[idx]; + diff(actual, expected) + }) + .sum::(); + + assert!( + avg_diff < 0.005, + "failed with avg_diff {}, whole ppr_vec is : {:?}, but it should be : {:?}", + avg_diff, + ppr_ans, + true_ans + ) + } + + #[test] + fn pressure_large_graph_test() { + let (graph, nodes) = pressure_large_graph(); + let mut source_bias = HashMap::new(); + nodes.iter().take(10).for_each(|idx| { + source_bias.insert(*idx, graph.to_index(*idx) as f64); + }); + + let ppr_ans = naive_ppr(&graph, 0.15_f64, source_bias, 15); + let ans_sum = ppr_ans.values().copied().sum::(); + assert!(ans_sum - 1.0 < f64::EPSILON); + } +}
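
Usage sketch for the two PPR entry points added above, mirroring the unit tests and benchmarks: a minimal toy graph, a damping factor of 0.15, and a residue threshold of 1e-4 are illustrative values only, not recommended defaults.

```rust
use std::collections::HashMap;

use petgraph::prelude::StableDiGraph;
use soul_mem::utils::graph_algo::ord_float::OrdFloat;
use soul_mem::utils::graph_algo::ppr::{naive_ppr, weighted_ppr_fp};

fn main() {
    // Build a tiny weighted digraph: A -> B, A -> C, B -> C.
    let mut graph: StableDiGraph<String, f64> = StableDiGraph::new();
    let a = graph.add_node("A".to_string());
    let b = graph.add_node("B".to_string());
    let c = graph.add_node("C".to_string());
    graph.add_edge(a, b, 1.0);
    graph.add_edge(a, c, 1.0);
    graph.add_edge(b, c, 1.0);

    // Personalization (source bias): restart mass concentrated on A.
    let mut source_bias = HashMap::new();
    source_bias.insert(a, 1.0_f64);

    // Power-iteration PPR: damping factor 0.15, 15 iterations.
    let ppr = naive_ppr(&graph, 0.15_f64, source_bias, 15);
    println!("power iteration: {:?}", ppr);

    // Forward-push PPR needs an Ord-able measure, hence OrdFloat<f64>,
    // plus a residue threshold and an edge-weight closure.
    let mut source_bias_fp = HashMap::new();
    source_bias_fp.insert(a, OrdFloat::from(1.0_f64));
    let ppr_fp = weighted_ppr_fp(
        &graph,
        OrdFloat::from(0.15_f64),
        source_bias_fp,
        OrdFloat::from(0.0001_f64),
        |_edge, _query| OrdFloat::from(1.0_f64),
        &(),
    );
    println!("forward push: {:?}", ppr_fp);
}
```

Both calls return a `HashMap` from node id to PPR mass that sums to roughly 1; the forward-push variant trades a small amount of accuracy (bounded by the residue threshold) for much lower cost on larger graphs, which is what the `ppr` benchmark measures.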
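
The retrieve.rs module above introduces a strategy-pattern pair (`RetrStrategy` / `RetrRequest`) whose concrete strategies are still `todo!()`. The following self-contained sketch shows how a concrete strategy is expected to plug in; `MemoryId` is aliased to `u64` here, and `RetrRecent` / `RecentRequest` with its plain id list are hypothetical stand-ins for illustration only (the real request types carry an `Arc<WorkingMemory>`).

```rust
use std::sync::Arc;

// Stand-in for crate::memory::memory_note::MemoryId.
type MemoryId = u64;

// Local copies of the trait pair from src/memory/algo/retrieve.rs.
trait RetrRequest {}

trait RetrStrategy {
    type Request: RetrRequest;
    fn retrieve(&self, request: Self::Request) -> Vec<MemoryId>;
}

/// Toy strategy: return the newest `max_results` ids carried by the request.
struct RetrRecent {
    max_results: usize,
}

struct RecentRequest {
    // In the crate this would be an Arc<WorkingMemory>; a plain id list
    // keeps the sketch self-contained.
    recent_ids: Arc<Vec<MemoryId>>,
}

impl RetrRequest for RecentRequest {}

impl RetrStrategy for RetrRecent {
    type Request = RecentRequest;
    fn retrieve(&self, request: Self::Request) -> Vec<MemoryId> {
        request
            .recent_ids
            .iter()
            .rev() // newest ids are at the tail of the window
            .take(self.max_results)
            .copied()
            .collect()
    }
}

fn main() {
    let strategy = RetrRecent { max_results: 2 };
    let request = RecentRequest {
        recent_ids: Arc::new(vec![1, 2, 3, 4]),
    };
    assert_eq!(strategy.retrieve(request), vec![4, 3]);
}
```

Keeping the request type associated with the strategy lets each strategy demand exactly the inputs it needs while callers stay generic over `RetrStrategy`, and the `Arc`-based requests remain `Send` for concurrent retrieval.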