diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e3f02c..146abe3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,9 @@ All notable changes to this project will be documented in this file. ### New features +* Implement `once::OnceMap` to run computation only once and store the results in a hash map. * `singleflight::Group` now supports custom hashers for keys. +* `singleflight::Group::remove` now accepts any `&Q` where `Q: ?Sized + Hash + Eq` and `K: Borrow` aligning with standard HashMap's interface. ## [0.6.1] - 2026-01-11 diff --git a/README.md b/README.md index 57104ce..bb27b48 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ Mea (Make Easy Async) is a runtime-agnostic library providing essential synchron * [**Mutex**](https://docs.rs/mea/*/mea/mutex/struct.Mutex.html): A mutual exclusion primitive for protecting shared data. * [**Once**](https://docs.rs/mea/*/mea/once/struct.Once.html): A primitive that ensures a one-time asynchronous operation runs at most once, even when called concurrently. * [**OnceCell**](https://docs.rs/mea/*/mea/once/struct.OnceCell.html): A cell that can be written to at most once, providing safe, lazy initialization. +* [**OnceMap**](https://docs.rs/mea/*/mea/once/struct.OnceMap.html): A hash map that runs computation only once for each key and stores the result. * [**RwLock**](https://docs.rs/mea/*/mea/rwlock/struct.RwLock.html): A reader-writer lock that allows multiple readers or a single writer at a time. * [**Semaphore**](https://docs.rs/mea/*/mea/semaphore/struct.Semaphore.html): A synchronization primitive that controls access to a shared resource. * [**ShutdownSend & ShutdownRecv**](https://docs.rs/mea/*/mea/shutdown/): A composite synchronization primitive for managing shutdown signals. @@ -74,6 +75,7 @@ This crate collects runtime-agnostic synchronization primitives from spare parts * **Latch** is inspired by [`latches`](https://github.com/mirromutth/latches), with a different implementation based on the internal `CountdownState` primitive. No `wait` or `watch` method is provided, since it can be easily implemented by [composing delay futures](https://docs.rs/fastimer/*/fastimer/fn.timeout.html). No sync variant is provided, since it can be easily implemented with block_on of any runtime. * **Mutex** is derived from `tokio::sync::Mutex`. No blocking method is provided, since it can be easily implemented with block_on of any runtime. * **OnceCell** is derived from `tokio::sync::OnceCell`, but using our own semaphore implementation. +* **OnceMap** is inspired by `uv-once-map` but the interface and implementation are redesigned. * **RwLock** is derived from `tokio::sync::RwLock`, but the `max_readers` can be any `NonZeroUsize` (effectively any positive `usize`) instead of `[0, u32::MAX >> 3]`. No blocking method is provided, since it can be easily implemented with block_on of any runtime. * **Semaphore** is derived from `tokio::sync::Semaphore`, without `close` method since it is quite tricky to use. And thus, this semaphore doesn't have the limitation of max permits. Besides, new methods like `forget_exact` are added to fit the specific use case. * **WaitGroup** is inspired by [`waitgroup-rs`](https://github.com/laizy/waitgroup-rs), providing different API flavor with a different implementation based on the internal `CountdownState` primitive. diff --git a/mea/src/lib.rs b/mea/src/lib.rs index 54725eb..111868b 100644 --- a/mea/src/lib.rs +++ b/mea/src/lib.rs @@ -29,6 +29,7 @@ //! * [`Once`]: A primitive that ensures a one-time asynchronous operation runs at most once, even //! when called concurrently //! * [`OnceCell`]: A cell that can be written to at most once and provides safe concurrent access +//! * [`OnceMap`]: A hash map that runs computation only once for each key and stores the result. //! * [`RwLock`]: A reader-writer lock that allows multiple readers or a single writer at a time //! * [`Semaphore`]: A synchronization primitive that controls access to a shared resource //! * [`ShutdownSend`] & [`ShutdownRecv`]: A composite synchronization primitive for managing @@ -61,6 +62,7 @@ //! [`Mutex`]: mutex::Mutex //! [`Once`]: once::Once //! [`OnceCell`]: once::OnceCell +//! [`OnceMap`]: once::OnceMap //! [`RwLock`]: rwlock::RwLock //! [`Semaphore`]: semaphore::Semaphore //! [`ShutdownSend`]: shutdown::ShutdownSend diff --git a/mea/src/once/mod.rs b/mea/src/once/mod.rs index b4e9f5f..095b9c0 100644 --- a/mea/src/once/mod.rs +++ b/mea/src/once/mod.rs @@ -24,9 +24,8 @@ #[allow(clippy::module_inception)] mod once; mod once_cell; +mod once_map; pub use self::once::Once; pub use self::once_cell::OnceCell; - -#[cfg(test)] -mod tests; +pub use self::once_map::OnceMap; diff --git a/mea/src/once/once.rs b/mea/src/once/once/mod.rs similarity index 99% rename from mea/src/once/once.rs rename to mea/src/once/once/mod.rs index e3e17e7..cb5efc3 100644 --- a/mea/src/once/once.rs +++ b/mea/src/once/once/mod.rs @@ -20,6 +20,9 @@ use std::task::Poll; use crate::internal::CountdownState; use crate::semaphore::Semaphore; +#[cfg(test)] +mod tests; + /// A synchronization primitive which can be used to run a one-time async initialization. /// /// Unlike [`std::sync::Once`], this type never blocks a thread. The provided closure must diff --git a/mea/src/once/once/tests.rs b/mea/src/once/once/tests.rs new file mode 100644 index 0000000..282c252 --- /dev/null +++ b/mea/src/once/once/tests.rs @@ -0,0 +1,179 @@ +// Copyright 2024 tison +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; +use std::time::Duration; + +use tokio_test::assert_ready; + +use super::*; +use crate::latch::Latch; +use crate::test_runtime; + +#[tokio::test] +async fn test_call_once_runs_only_once() { + static ONCE: Once = Once::new(); + static COUNTER: AtomicUsize = AtomicUsize::new(0); + + assert!(!ONCE.is_completed()); + + ONCE.call_once(async || { + COUNTER.fetch_add(1, Ordering::SeqCst); + }) + .await; + + assert!(ONCE.is_completed()); + assert_eq!(COUNTER.load(Ordering::SeqCst), 1); + + // Second call should not run the closure + ONCE.call_once(async || { + COUNTER.fetch_add(1, Ordering::SeqCst); + }) + .await; + + assert_eq!(COUNTER.load(Ordering::SeqCst), 1); +} + +#[test] +fn test_once_multi_task() { + static ONCE: Once = Once::new(); + static COUNTER: AtomicUsize = AtomicUsize::new(0); + + test_runtime().block_on(async { + const N: usize = 100; + + let latch = Arc::new(Latch::new(N as u32)); + let mut handles = Vec::with_capacity(N); + + for _ in 0..N { + let latch = latch.clone(); + handles.push(tokio::spawn(async move { + ONCE.call_once(async || { + COUNTER.fetch_add(1, Ordering::SeqCst); + }) + .await; + latch.count_down(); + })); + } + + latch.wait().await; + + for handle in handles { + handle.await.unwrap(); + } + + // Only one task should have incremented the counter + assert_eq!(COUNTER.load(Ordering::SeqCst), 1); + assert!(ONCE.is_completed()); + }); +} + +#[tokio::test] +async fn test_once_cancelled() { + static ONCE: Once = Once::new(); + static COUNTER: AtomicUsize = AtomicUsize::new(0); + + let handle1 = tokio::spawn(async { + let fut = ONCE.call_once(async || { + tokio::time::sleep(Duration::from_millis(1000)).await; + COUNTER.fetch_add(1, Ordering::SeqCst); + }); + let timeout = tokio::time::timeout(Duration::from_millis(1), fut).await; + assert!(timeout.is_err()); + }); + + let handle2 = tokio::spawn(async { + tokio::time::sleep(Duration::from_millis(100)).await; + ONCE.call_once(async || { + COUNTER.fetch_add(10, Ordering::SeqCst); + }) + .await; + }); + + handle1.await.unwrap(); + handle2.await.unwrap(); + + // The second task should have run since the first was cancelled + assert_eq!(COUNTER.load(Ordering::SeqCst), 10); + assert!(ONCE.is_completed()); +} + +#[tokio::test] +async fn test_once_debug() { + let once = Once::new(); + let debug_str = format!("{:?}", once); + assert!(debug_str.contains("Once")); + assert!(debug_str.contains("done")); + assert!(debug_str.contains("false")); + + once.call_once(async || {}).await; + + let debug_str = format!("{:?}", once); + assert!(debug_str.contains("true")); +} + +#[tokio::test] +async fn test_once_default() { + let once = Once::default(); + assert!(!once.is_completed()); +} + +#[tokio::test] +async fn test_once_retry_after_panic() { + static ONCE: Once = Once::new(); + static COUNTER: AtomicUsize = AtomicUsize::new(0); + + let handle = tokio::spawn(async { + ONCE.call_once(async || { + COUNTER.fetch_add(1, Ordering::SeqCst); + panic!("boom"); + }) + .await; + }); + + let err = handle.await.expect_err("once init should panic"); + assert!(err.is_panic()); + + ONCE.call_once(async || { + COUNTER.fetch_add(1, Ordering::SeqCst); + }) + .await; + + assert_eq!(COUNTER.load(Ordering::SeqCst), 2); + assert!(ONCE.is_completed()); +} + +#[tokio::test] +async fn test_once_wait() { + // wait after call_once completed + { + let once = Once::new(); + once.call_once(async || {}).await; + assert_ready!(tokio_test::task::spawn(once.wait()).poll()); + } + + // wait before call_once completed + { + static ONCE: Once = Once::new(); + let handle = tokio::spawn(async { + ONCE.wait().await; + }); + + tokio::time::sleep(Duration::from_millis(100)).await; + ONCE.call_once(async || {}).await; + handle.await.unwrap(); + } +} diff --git a/mea/src/once/once_cell.rs b/mea/src/once/once_cell/mod.rs similarity index 99% rename from mea/src/once/once_cell.rs rename to mea/src/once/once_cell/mod.rs index 65b0a29..a104703 100644 --- a/mea/src/once/once_cell.rs +++ b/mea/src/once/once_cell/mod.rs @@ -22,6 +22,9 @@ use std::sync::atomic::Ordering; use crate::semaphore::Semaphore; use crate::semaphore::SemaphorePermit; +#[cfg(test)] +mod tests; + /// A thread-safe cell which can nominally be written to only once. /// /// # Examples diff --git a/mea/src/once/tests.rs b/mea/src/once/once_cell/tests.rs similarity index 56% rename from mea/src/once/tests.rs rename to mea/src/once/once_cell/tests.rs index fb33ec3..0279de4 100644 --- a/mea/src/once/tests.rs +++ b/mea/src/once/once_cell/tests.rs @@ -14,17 +14,13 @@ use std::sync::Arc; use std::sync::atomic::AtomicBool; -use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; use std::time::Duration; use tokio::sync::Mutex; -use tokio_test::assert_ready; -use super::Once; -use super::once_cell::OnceCell; +use super::*; use crate::latch::Latch; -use crate::test_runtime; struct Foo { value: Arc, @@ -204,158 +200,3 @@ async fn get_mut_or_try_init() { .unwrap(); assert_eq!(v, 15); } - -#[tokio::test] -async fn test_call_once_runs_only_once() { - static ONCE: Once = Once::new(); - static COUNTER: AtomicUsize = AtomicUsize::new(0); - - assert!(!ONCE.is_completed()); - - ONCE.call_once(async || { - COUNTER.fetch_add(1, Ordering::SeqCst); - }) - .await; - - assert!(ONCE.is_completed()); - assert_eq!(COUNTER.load(Ordering::SeqCst), 1); - - // Second call should not run the closure - ONCE.call_once(async || { - COUNTER.fetch_add(1, Ordering::SeqCst); - }) - .await; - - assert_eq!(COUNTER.load(Ordering::SeqCst), 1); -} - -#[test] -fn test_once_multi_task() { - static ONCE: Once = Once::new(); - static COUNTER: AtomicUsize = AtomicUsize::new(0); - - test_runtime().block_on(async { - const N: usize = 100; - - let latch = Arc::new(Latch::new(N as u32)); - let mut handles = Vec::with_capacity(N); - - for _ in 0..N { - let latch = latch.clone(); - handles.push(tokio::spawn(async move { - ONCE.call_once(async || { - COUNTER.fetch_add(1, Ordering::SeqCst); - }) - .await; - latch.count_down(); - })); - } - - latch.wait().await; - - for handle in handles { - handle.await.unwrap(); - } - - // Only one task should have incremented the counter - assert_eq!(COUNTER.load(Ordering::SeqCst), 1); - assert!(ONCE.is_completed()); - }); -} - -#[tokio::test] -async fn test_once_cancelled() { - static ONCE: Once = Once::new(); - static COUNTER: AtomicUsize = AtomicUsize::new(0); - - let handle1 = tokio::spawn(async { - let fut = ONCE.call_once(async || { - tokio::time::sleep(Duration::from_millis(1000)).await; - COUNTER.fetch_add(1, Ordering::SeqCst); - }); - let timeout = tokio::time::timeout(Duration::from_millis(1), fut).await; - assert!(timeout.is_err()); - }); - - let handle2 = tokio::spawn(async { - tokio::time::sleep(Duration::from_millis(100)).await; - ONCE.call_once(async || { - COUNTER.fetch_add(10, Ordering::SeqCst); - }) - .await; - }); - - handle1.await.unwrap(); - handle2.await.unwrap(); - - // The second task should have run since the first was cancelled - assert_eq!(COUNTER.load(Ordering::SeqCst), 10); - assert!(ONCE.is_completed()); -} - -#[tokio::test] -async fn test_once_debug() { - let once = Once::new(); - let debug_str = format!("{:?}", once); - assert!(debug_str.contains("Once")); - assert!(debug_str.contains("done")); - assert!(debug_str.contains("false")); - - once.call_once(async || {}).await; - - let debug_str = format!("{:?}", once); - assert!(debug_str.contains("true")); -} - -#[tokio::test] -async fn test_once_default() { - let once = Once::default(); - assert!(!once.is_completed()); -} - -#[tokio::test] -async fn test_once_retry_after_panic() { - static ONCE: Once = Once::new(); - static COUNTER: AtomicUsize = AtomicUsize::new(0); - - let handle = tokio::spawn(async { - ONCE.call_once(async || { - COUNTER.fetch_add(1, Ordering::SeqCst); - panic!("boom"); - }) - .await; - }); - - let err = handle.await.expect_err("once init should panic"); - assert!(err.is_panic()); - - ONCE.call_once(async || { - COUNTER.fetch_add(1, Ordering::SeqCst); - }) - .await; - - assert_eq!(COUNTER.load(Ordering::SeqCst), 2); - assert!(ONCE.is_completed()); -} - -#[tokio::test] -async fn test_once_wait() { - // wait after call_once completed - { - let once = Once::new(); - once.call_once(async || {}).await; - assert_ready!(tokio_test::task::spawn(once.wait()).poll()); - } - - // wait before call_once completed - { - static ONCE: Once = Once::new(); - let handle = tokio::spawn(async { - ONCE.wait().await; - }); - - tokio::time::sleep(Duration::from_millis(100)).await; - ONCE.call_once(async || {}).await; - handle.await.unwrap(); - } -} diff --git a/mea/src/once/once_map/mod.rs b/mea/src/once/once_map/mod.rs new file mode 100644 index 0000000..1515bb1 --- /dev/null +++ b/mea/src/once/once_map/mod.rs @@ -0,0 +1,191 @@ +// Copyright 2024 tison +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::borrow::Borrow; +use std::collections::HashMap; +use std::hash::BuildHasher; +use std::hash::Hash; +use std::hash::RandomState; +use std::sync::Arc; + +use crate::internal::Mutex; +use crate::once::OnceCell; + +#[cfg(test)] +mod tests; + +/// A hash map that runs computation only once for each key and stores the result. +/// +/// Note that this always clones the value out of the underlying map. Because of this, it's common +/// to wrap the `V` in an `Arc` to make cloning cheap. +#[derive(Debug)] +pub struct OnceMap { + map: Mutex>, S>>, +} + +impl Default for OnceMap +where + K: Eq + Hash + Clone, + V: Clone, + S: BuildHasher + Clone + Default, +{ + fn default() -> Self { + Self::with_hasher(S::default()) + } +} + +impl OnceMap +where + K: Eq + Hash + Clone, + V: Clone, +{ + /// Creates a new OnceMap with the default hasher. + pub fn new() -> Self { + Self { + map: Mutex::new(HashMap::new()), + } + } + + /// Creates a new OnceMap with the default hasher and the specified capacity. + pub fn with_capacity(capacity: usize) -> Self { + Self { + map: Mutex::new(HashMap::with_capacity(capacity)), + } + } +} + +impl OnceMap +where + K: Eq + Hash + Clone, + V: Clone, + S: BuildHasher + Clone, +{ + /// Creates a new OnceMap with the given hasher. + pub fn with_hasher(hasher: S) -> Self { + Self { + map: Mutex::new(HashMap::with_hasher(hasher)), + } + } + + /// Create a OnceMap with the specified capacity and hasher. + pub fn with_capacity_and_hasher(capacity: usize, hasher: S) -> Self { + Self { + map: Mutex::new(HashMap::with_capacity_and_hasher(capacity, hasher)), + } + } + + /// Compute the value for the given key if absent. + /// + /// If the value for the key is already being computed by another task, this task will wait for + /// the computation to finish and return the result. + pub async fn compute(&self, key: K, func: F) -> V + where + F: AsyncFnOnce() -> V, + { + // 1. Get or create the OnceCell. + let cell = { + let mut map = self.map.lock(); + map.entry(key.clone()) + .or_insert_with(|| Arc::new(OnceCell::new())) + .clone() + }; + + // 2. Try to initialize the cell. + // OnceCell::get_or_init guarantees that only one task executes the closure. + let res = cell.get_or_init(func).await; + res.clone() + } + + /// Compute the value for the given key if absent. + /// + /// If the value for the key is already being computed by another task, this task will wait for + /// the computation to finish and return the result. + /// + /// If the computation fails, the error is returned and the value is not stored. Other tasks + /// waiting for the value will retry the computation. + pub async fn try_compute(&self, key: K, func: F) -> Result + where + F: AsyncFnOnce() -> Result, + { + // 1. Get or create the OnceCell. + let cell = { + let mut map = self.map.lock(); + map.entry(key.clone()) + .or_insert_with(|| Arc::new(OnceCell::new())) + .clone() + }; + + // 2. Try to initialize the cell. + // OnceCell::get_or_try_init guarantees that only one task executes the closure. + let res = cell.get_or_try_init(func).await?; + Ok(res.clone()) + } + + /// Get a clone of the value for the given key if exists. + pub fn get(&self, key: &Q) -> Option + where + K: Borrow, + Q: Hash + Eq + ?Sized, + { + let map = self.map.lock(); + let cell = map.get(key)?; + cell.get().cloned() + } + + /// Remove the given key from the map. + /// + /// If you need to get the value that has been remove, use the [`remove`] method instead. + /// + /// [`remove`]: Self::remove + pub fn discard(&self, key: &Q) + where + K: Borrow, + Q: Hash + Eq + ?Sized, + { + let mut map = self.map.lock(); + map.remove(key); + } + + /// Remove the given key from the map and return a *clone* of the value if exists. + /// + /// If you do not need to get the value that has been removed, use the [`discard`] method + /// instead. + /// + /// [`discard`]: Self::discard + pub fn remove(&self, key: &Q) -> Option + where + K: Borrow, + Q: Hash + Eq + ?Sized, + { + let cell = self.map.lock().remove(key)?; + cell.get().cloned() + } +} + +impl FromIterator<(K, V)> for OnceMap +where + K: Eq + Hash + Clone, + V: Clone, + S: Default + BuildHasher + Clone, +{ + fn from_iter>(iter: T) -> Self { + Self { + map: Mutex::new( + iter.into_iter() + .map(|(k, v)| (k, Arc::new(OnceCell::from_value(v)))) + .collect(), + ), + } + } +} diff --git a/mea/src/once/once_map/tests.rs b/mea/src/once/once_map/tests.rs new file mode 100644 index 0000000..a6749aa --- /dev/null +++ b/mea/src/once/once_map/tests.rs @@ -0,0 +1,97 @@ +// Copyright 2024 tison +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; +use std::time::Duration; + +use crate::once::OnceMap; + +#[tokio::test] +async fn test_compute() { + let map = OnceMap::new(); + let v = map.compute("key", async || 1).await; + assert_eq!(v, 1); + let v = map.compute("key", async || 2).await; + assert_eq!(v, 1); +} + +#[tokio::test] +async fn test_compute_concurrent() { + let map = Arc::new(OnceMap::new()); + let cnt = Arc::new(AtomicUsize::new(0)); + let mut handles = Vec::new(); + + for _ in 0..10 { + let map = map.clone(); + let cnt = cnt.clone(); + handles.push(tokio::spawn(async move { + map.compute("key", async move || { + cnt.fetch_add(1, Ordering::SeqCst); + tokio::time::sleep(Duration::from_millis(10)).await; + 42 + }) + .await + })); + } + + for h in handles { + assert_eq!(h.await.unwrap(), 42); + } + assert_eq!(cnt.load(Ordering::SeqCst), 1); +} + +#[tokio::test] +async fn test_try_compute() { + let map = OnceMap::new(); + + // Fail first + let res: Result = map.try_compute("key", async || Err("fail")).await; + assert_eq!(res, Err("fail")); + + // Success then + let res: Result = map.try_compute("key", async || Ok::(1)).await; + assert_eq!(res, Ok(1)); + + // Cached + let res: Result = map.try_compute("key", async || Ok::(2)).await; + assert_eq!(res, Ok(1)); +} + +#[tokio::test] +async fn test_get_remove() { + let map = OnceMap::new(); + assert_eq!(map.get("key"), None); + assert_eq!(map.remove("key"), None); + + map.compute("key", async || 1).await; + assert_eq!(map.get("key"), Some(1)); + + let v = map.remove("key"); + assert_eq!(v, Some(1)); + assert_eq!(map.get("key"), None); + + map.compute("key", async || 2).await; + map.discard("key"); + assert_eq!(map.get("key"), None); +} + +#[tokio::test] +async fn test_from_iter() { + let map: OnceMap<_, _> = vec![("a", 1), ("b", 2)].into_iter().collect(); + assert_eq!(map.get("a"), Some(1)); + assert_eq!(map.get("b"), Some(2)); + assert_eq!(map.get("c"), None); +} diff --git a/mea/src/singleflight/mod.rs b/mea/src/singleflight/mod.rs index 8d5ab1a..83b59bc 100644 --- a/mea/src/singleflight/mod.rs +++ b/mea/src/singleflight/mod.rs @@ -14,6 +14,7 @@ //! Singleflight provides a duplicate function call suppression mechanism. +use std::borrow::Borrow; use std::collections::HashMap; use std::hash::BuildHasher; use std::hash::Hash; @@ -135,7 +136,6 @@ where // OnceCell::get_or_init guarantees that only one task executes the closure. let res = cell .get_or_init(async || { - // I am the leader. let result = func().await; // Cleanup: remove the key from the map. @@ -155,11 +155,96 @@ where res.clone() } + /// Executes and returns the results of the given function, making sure that only one execution + /// is in-flight for a given key at a time. + /// + /// If a duplicate comes in, the duplicate caller waits for the original to complete and + /// receives the same results. + /// + /// If the computation fails, the error is returned for the caller. Other tasks waiting for the + /// result will retry the computation. + /// + /// Once the function completes successfully, the key, if not [`forgotten`], is removed from + /// the group, allowing future calls with the same key to execute the function again. + /// + /// [`forgotten`]: Self::forget + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use std::sync::atomic::AtomicUsize; + /// use std::sync::atomic::Ordering; + /// use std::time::Duration; + /// + /// use mea::singleflight::Group; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let group = Group::new(); + /// + /// let fut1 = group.try_work("key", || async move { + /// // simulate heavy work to avoid immediate completion + /// tokio::time::sleep(Duration::from_millis(100)).await; + /// Err::<_, &'static str>("fut1") + /// }); + /// + /// let fut2 = group.try_work("key", || async move { + /// // simulate heavy work to avoid immediate completion + /// tokio::time::sleep(Duration::from_millis(200)).await; + /// Ok::<_, &'static str>("fut2") + /// }); + /// + /// let (r1, r2) = tokio::join!(fut1, fut2); + /// + /// assert_eq!(r1, Err("fut1")); + /// assert_eq!(r2, Ok("fut2")); + /// # } + /// ``` + pub async fn try_work(&self, key: K, func: F) -> Result + where + F: AsyncFnOnce() -> Result, + { + // 1. Get or create the OnceCell. + let cell = { + let mut map = self.map.lock(); + map.entry(key.clone()) + .or_insert_with(|| Arc::new(OnceCell::new())) + .clone() + }; + + // 2. Try to initialize the cell. + // OnceCell::get_or_try_init guarantees that only one task executes the closure. + let res = cell + .get_or_try_init(async || { + let result = func().await?; + + // Cleanup: remove the key from the map. + // We must ensure we remove the entry corresponding to *this* cell. + let mut map = self.map.lock(); + if let Some(existing) = map.get(&key) { + // Check if the map still points to our cell. + if Arc::ptr_eq(&cell, existing) { + map.remove(&key); + } + } + + Ok(result) + }) + .await?; + + Ok(res.clone()) + } + /// Forgets about the given key. /// /// Future calls to `work` for this key will call the function rather than waiting for an /// earlier call to complete. Existing calls to `work` for this key are not affected. - pub fn forget(&self, key: &K) { + pub fn forget(&self, key: &Q) + where + K: Borrow, + Q: Hash + Eq + ?Sized, + { let mut map = self.map.lock(); map.remove(key); }